LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRef.h
Go to the documentation of this file.
1//===-- ConstrainRef.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
19#include "llzk/Util/Hash.h"
20
21#include <mlir/Analysis/DataFlowFramework.h>
22#include <mlir/Dialect/Arith/IR/Arith.h>
23#include <mlir/Pass/AnalysisManager.h>
24
25#include <llvm/ADT/EquivalenceClasses.h>
26
27#include <unordered_set>
28#include <vector>
29
30namespace llzk {
31
/// A pair of APInt bounds (low, high) describing a range of array indices.
/// NOTE(review): whether `high` is inclusive is not visible in this header --
/// confirm against operator< / print in ConstrainRef.cpp.
36 using IndexRange = std::pair<mlir::APInt, mlir::APInt>;
37
38public:
/// Index by a struct field, holding the FieldDefOp directly (not wrapped in a
/// symbol lookup result).
39 explicit ConstrainRefIndex(component::FieldDefOp f) : index(f) {}
/// A single concrete array index.
41 explicit ConstrainRefIndex(mlir::APInt i) : index(i) {}
/// Convenience overload: converts `i` to an APInt via toAPInt.
42 explicit ConstrainRefIndex(int64_t i) : index(toAPInt(i)) {}
/// A range of array indices from `low` to `high` (see the IndexRange note).
43 ConstrainRefIndex(mlir::APInt low, mlir::APInt high) : index(IndexRange {low, high}) {}
/// A range of array indices given as a prebuilt (low, high) pair.
44 explicit ConstrainRefIndex(IndexRange r) : index(r) {}
45
46 bool isField() const {
47 return std::holds_alternative<SymbolLookupResult<component::FieldDefOp>>(index) ||
48 std::holds_alternative<component::FieldDefOp>(index);
49 }
51 ensure(isField(), "ConstrainRefIndex: field requested but not contained");
52 if (std::holds_alternative<component::FieldDefOp>(index)) {
53 return std::get<component::FieldDefOp>(index);
54 }
55 return std::get<SymbolLookupResult<component::FieldDefOp>>(index).get();
56 }
57
58 bool isIndex() const { return std::holds_alternative<mlir::APInt>(index); }
59 mlir::APInt getIndex() const {
60 ensure(isIndex(), "ConstrainRefIndex: index requested but not contained");
61 return std::get<mlir::APInt>(index);
62 }
63
64 bool isIndexRange() const { return std::holds_alternative<IndexRange>(index); }
65 IndexRange getIndexRange() const {
66 ensure(isIndexRange(), "ConstrainRefIndex: index range requested but not contained");
67 return std::get<IndexRange>(index);
68 }
69
70 inline void dump() const { print(llvm::errs()); }
71 void print(mlir::raw_ostream &os) const;
72
73 inline bool operator==(const ConstrainRefIndex &rhs) const {
74 if (isField() && rhs.isField()) {
75 // We compare the underlying fields, since the field could be in a symbol
76 // lookup or not.
77 return getField() == rhs.getField();
78 }
79 if (isIndex() && rhs.isIndex()) {
80 return safeEq(mlir::APSInt(getIndex()), mlir::APSInt(rhs.getIndex()));
81 }
82 return index == rhs.index;
83 }
84
85 bool operator<(const ConstrainRefIndex &rhs) const;
86
87 bool operator>(const ConstrainRefIndex &rhs) const { return rhs < *this; }
88
89 struct Hash {
90 size_t operator()(const ConstrainRefIndex &c) const;
91 };
92
93 inline size_t getHash() const { return Hash {}(*this); }
94
95private:
102 std::variant<
104 index;
105};
106
107static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefIndex &rhs) {
108 rhs.print(os);
109 return os;
110}
111
122
123public:
/// Produce all possible ConstrainRefs that are present starting from `root`.
/// `tables` and `mod` are presumably used to look up struct/field definitions
/// while expanding `root` -- confirm in ConstrainRef.cpp.
125 static std::vector<ConstrainRef>
126 getAllConstrainRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ConstrainRef root);
127
129 static std::vector<ConstrainRef>
131
134 static std::vector<ConstrainRef>
136
137 explicit ConstrainRef(mlir::BlockArgument b) : root(b), fieldRefs(), constantVal(std::nullopt) {}
138 ConstrainRef(mlir::BlockArgument b, std::vector<ConstrainRefIndex> f)
139 : root(b), fieldRefs(std::move(f)), constantVal(std::nullopt) {}
140
142 : root(createOp), fieldRefs(), constantVal(std::nullopt) {}
143 ConstrainRef(component::CreateStructOp createOp, std::vector<ConstrainRefIndex> f)
144 : root(createOp), fieldRefs(std::move(f)), constantVal(std::nullopt) {}
145
146 explicit ConstrainRef(felt::FeltConstantOp c) : root(std::nullopt), fieldRefs(), constantVal(c) {}
147 explicit ConstrainRef(mlir::arith::ConstantIndexOp c)
148 : root(std::nullopt), fieldRefs(), constantVal(c) {}
150 : root(std::nullopt), fieldRefs(), constantVal(c) {}
151
152 mlir::Type getType() const;
153
154 bool isConstantFelt() const {
155 return constantVal.has_value() && std::holds_alternative<felt::FeltConstantOp>(*constantVal);
156 }
157 bool isConstantIndex() const {
158 return constantVal.has_value() &&
159 std::holds_alternative<mlir::arith::ConstantIndexOp>(*constantVal);
160 }
161 bool isTemplateConstant() const {
162 return constantVal.has_value() &&
163 std::holds_alternative<polymorphic::ConstReadOp>(*constantVal);
164 }
165 bool isConstant() const { return constantVal.has_value(); }
166 bool isConstantInt() const { return isConstantFelt() || isConstantIndex(); }
167
168 bool isFeltVal() const { return mlir::isa<felt::FeltType>(getType()); }
169 bool isIndexVal() const { return mlir::isa<mlir::IndexType>(getType()); }
170 bool isIntegerVal() const { return mlir::isa<mlir::IntegerType>(getType()); }
171 bool isTypeVarVal() const { return mlir::isa<polymorphic::TypeVarType>(getType()); }
172 bool isScalar() const {
173 return isConstant() || isFeltVal() || isIndexVal() || isIntegerVal() || isTypeVarVal();
174 }
175 bool isSignal() const { return isSignalType(getType()); }
176
177 bool isBlockArgument() const {
178 return root.has_value() && std::holds_alternative<mlir::BlockArgument>(*root);
179 }
180 mlir::BlockArgument getBlockArgument() const {
181 ensure(isBlockArgument(), "is not a block argument");
182 return std::get<mlir::BlockArgument>(*root);
183 }
184 unsigned getInputNum() const { return getBlockArgument().getArgNumber(); }
185
186 bool isCreateStructOp() const {
187 return root.has_value() && std::holds_alternative<component::CreateStructOp>(*root);
188 }
190 ensure(isCreateStructOp(), "is not a create struct op");
191 return std::get<component::CreateStructOp>(*root);
192 }
193
194 mlir::APInt getConstantFeltValue() const {
195 ensure(isConstantFelt(), __FUNCTION__ + mlir::Twine(" requires a constant felt!"));
196 return std::get<felt::FeltConstantOp>(*constantVal).getValueAttr().getValue();
197 }
198 mlir::APInt getConstantIndexValue() const {
199 ensure(isConstantIndex(), __FUNCTION__ + mlir::Twine(" requires a constant index!"));
200 return toAPInt(std::get<mlir::arith::ConstantIndexOp>(*constantVal).value());
201 }
202 mlir::APInt getConstantValue() const {
203 ensure(
205 __FUNCTION__ + mlir::Twine(" requires a constant int type!")
206 );
208 }
209
/// Returns true iff `prefix` is a valid prefix of this reference.
211 bool isValidPrefix(const ConstrainRef &prefix) const;
212
/// If `prefix` is a valid prefix of this reference, returns the index path
/// that remains after removing it; otherwise returns failure.
217 mlir::FailureOr<std::vector<ConstrainRefIndex>> getSuffix(const ConstrainRef &prefix) const;
218
/// If `prefix` is a valid prefix of this reference, returns a new reference in
/// which `prefix` is replaced with `other`; otherwise returns failure.
225 mlir::FailureOr<ConstrainRef>
226 translate(const ConstrainRef &prefix, const ConstrainRef &other) const;
227
229 mlir::FailureOr<ConstrainRef> getParentPrefix() const {
230 if (isConstantFelt() || fieldRefs.empty()) {
231 return mlir::failure();
232 }
233 auto copy = *this;
234 copy.fieldRefs.pop_back();
235 return copy;
236 }
237
/// Get all direct children of this ConstrainRef, assuming this ref is not a
/// scalar. `tables` and `mod` are presumably used to resolve struct field
/// definitions -- confirm in the implementation.
239 std::vector<ConstrainRef>
240 getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const;
241
243 auto copy = *this;
244 copy.fieldRefs.push_back(r);
245 return copy;
246 }
247
249 assert(other.isConstantIndex());
251 }
252
253 const std::vector<ConstrainRefIndex> &getPieces() const { return fieldRefs; }
254
255 void print(mlir::raw_ostream &os) const;
256 void dump() const { print(llvm::errs()); }
257
258 bool operator==(const ConstrainRef &rhs) const;
259
260 bool operator!=(const ConstrainRef &rhs) const { return !(*this == rhs); }
261
262 // required for EquivalenceClasses usage
263 bool operator<(const ConstrainRef &rhs) const;
264
265 bool operator>(const ConstrainRef &rhs) const { return rhs < *this; }
266
267 struct Hash {
268 size_t operator()(const ConstrainRef &val) const;
269 };
270
271private:
/// Where the reference starts: a block argument or a CreateStructOp.
/// Left empty by the constant-reference constructors.
281 std::optional<std::variant<mlir::BlockArgument, component::CreateStructOp>> root;
282
/// Index path applied to `root`, ordered from the root outward
/// (getParentPrefix drops the last element to obtain the parent).
283 std::vector<ConstrainRefIndex> fieldRefs;
284 // Declared mutable so const get* accessors can call non-const methods on the
// stored op handle without const casts. Only the constant-reference
// constructors set it.
285 mutable std::optional<
286 std::variant<felt::FeltConstantOp, mlir::arith::ConstantIndexOp, polymorphic::ConstReadOp>>
287 constantVal;
288};
289
290mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs);
291
292/* ConstrainRefSet */
293
/// An unordered set of ConstrainRefs hashed with ConstrainRef::Hash.
/// The static_assert that follows checks it against the ScalarLatticeValue
/// requirements, presumably so it can be used as a dataflow lattice value --
/// confirm against the analysis that uses it.
294class ConstrainRefSet : public std::unordered_set<ConstrainRef, ConstrainRef::Hash> {
/// Shorthand for the underlying standard container.
295 using Base = std::unordered_set<ConstrainRef, ConstrainRef::Hash>;
296
297public:
/// Inherit all std::unordered_set constructors.
298 using Base::Base;
299
301
/// Print the set to `os` (defined out of line).
302 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs);
303};
304
305static_assert(
307 "ConstrainRefSet must satisfy the ScalarLatticeValue requirements"
308);
309
310} // namespace llzk
311
312namespace llvm {
313
314template <> struct DenseMapInfo<llzk::ConstrainRef> {
316 return llzk::ConstrainRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
317 }
319 return llzk::ConstrainRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
320 }
321 static unsigned getHashValue(const llzk::ConstrainRef &ref) {
322 if (ref == getEmptyKey() || ref == getTombstoneKey()) {
323 return llvm::hash_value(ref.getBlockArgument().getAsOpaquePointer());
324 }
325 return llzk::ConstrainRef::Hash {}(ref);
326 }
327 static bool isEqual(const llzk::ConstrainRef &lhs, const llzk::ConstrainRef &rhs) {
328 return lhs == rhs;
329 }
330};
331
332} // namespace llvm
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
Defines an index into an LLZK object.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
ConstrainRefIndex(mlir::APInt i)
size_t getHash() const
bool operator<(const ConstrainRefIndex &rhs) const
bool operator==(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(SymbolLookupResult< component::FieldDefOp > f)
ConstrainRefIndex(component::FieldDefOp f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefIndex(IndexRange r)
ConstrainRefIndex(mlir::APInt low, mlir::APInt high)
bool operator>(const ConstrainRefIndex &rhs) const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool isTypeVarVal() const
const std::vector< ConstrainRefIndex > & getPieces() const
bool operator<(const ConstrainRef &rhs) const
component::CreateStructOp getCreateStructOp() const
ConstrainRef(felt::FeltConstantOp c)
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
bool isIntegerVal() const
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isConstant() const
ConstrainRef createChild(ConstrainRef other) const
ConstrainRef(polymorphic::ConstReadOp c)
bool isTemplateConstant() const
ConstrainRef(mlir::BlockArgument b, std::vector< ConstrainRefIndex > f)
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
mlir::Type getType() const
std::vector< ConstrainRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this ConstrainRef, assuming this ref is not a scalar.
ConstrainRef(mlir::BlockArgument b)
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
bool operator==(const ConstrainRef &rhs) const
mlir::FailureOr< ConstrainRef > getParentPrefix() const
Create a new reference that is the immediate prefix of this reference if possible.
ConstrainRef(mlir::arith::ConstantIndexOp c)
bool isBlockArgument() const
ConstrainRef(component::CreateStructOp createOp, std::vector< ConstrainRefIndex > f)
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
mlir::BlockArgument getBlockArgument() const
ConstrainRef createChild(ConstrainRefIndex r) const
bool isSignal() const
bool isIndexVal() const
bool isFeltVal() const
ConstrainRef(component::CreateStructOp createOp)
bool operator!=(const ConstrainRef &rhs) const
bool isCreateStructOp() const
static std::vector< ConstrainRef > getAllConstrainRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ConstrainRef root)
Produce all possible ConstrainRefs that are present starting from the given root.
void dump() const
bool isConstantInt() const
mlir::APInt getConstantValue() const
bool operator>(const ConstrainRef &rhs) const
bool isScalar() const
unsigned getInputNum() const
llvm::APInt toAPInt(int64_t i)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:74
raw_ostream & operator<<(raw_ostream &os, const ConstrainRef &rhs)
bool isSignalType(Type type)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static bool isEqual(const llzk::ConstrainRef &lhs, const llzk::ConstrainRef &rhs)
static llzk::ConstrainRef getTombstoneKey()
static llzk::ConstrainRef getEmptyKey()
static unsigned getHashValue(const llzk::ConstrainRef &ref)
size_t operator()(const ConstrainRefIndex &c) const
size_t operator()(const ConstrainRef &val) const