39 os << std::get<0>(r) <<
':' << std::get<1>(r);
52 auto ll = std::get<0>(l), lu = std::get<1>(l);
53 auto rl = std::get<0>(r), ru = std::get<1>(r);
54 return ll.ult(rl) || (ll == rl && lu.ult(ru));
69 return llvm::hash_value(c.
getIndex());
72 return llvm::hash_value(std::get<0>(r)) ^ llvm::hash_value(std::get<1>(r));
91 mlir::succeeded(sDef),
95 return std::move(sDef.value());
98std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
99 mlir::SymbolTableCollection &tables, mlir::ModuleOp
mod, ArrayType arrayTy,
100 mlir::BlockArgument blockArg, std::vector<ConstrainRefIndex> fields = {}
102 std::vector<ConstrainRef> res;
104 res.emplace_back(blockArg, fields);
107 int64_t maxSz = arrayTy.getDimSize(0);
108 for (int64_t i = 0; i < maxSz; i++) {
109 auto elemTy = arrayTy.getElementType();
111 std::vector<ConstrainRefIndex> subFields = fields;
112 subFields.emplace_back(i);
114 if (
auto arrayElemTy = mlir::dyn_cast<ArrayType>(elemTy)) {
116 auto subRes = getAllConstrainRefs(tables,
mod, arrayElemTy, blockArg, subFields);
117 res.insert(res.end(), subRes.begin(), subRes.end());
118 }
else if (
auto structTy = mlir::dyn_cast<StructType>(elemTy)) {
120 auto subRes = getAllConstrainRefs(
123 res.insert(res.end(), subRes.begin(), subRes.end());
126 res.emplace_back(blockArg, subFields);
133std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
134 mlir::SymbolTableCollection &tables, mlir::ModuleOp
mod,
135 SymbolLookupResult<StructDefOp> structDefRes, mlir::BlockArgument blockArg,
136 std::vector<ConstrainRefIndex> fields = {}
138 std::vector<ConstrainRef> res;
140 res.emplace_back(blockArg, fields);
142 for (
auto f : structDefRes.get().getOps<FieldDefOp>()) {
143 std::vector<ConstrainRefIndex> subFields = fields;
149 auto structDefCopy = structDefRes;
151 tables, mlir::SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()),
152 std::move(structDefCopy),
mod.getOperation()
154 ensure(mlir::succeeded(fieldLookup),
"could not get SymbolLookupResult of existing FieldDefOp");
155 subFields.emplace_back(fieldLookup.value());
158 res.emplace_back(blockArg, subFields);
159 if (
auto structTy = mlir::dyn_cast<StructType>(f.getType())) {
161 auto subRes = getAllConstrainRefs(
164 res.insert(res.end(), subRes.begin(), subRes.end());
165 }
else if (
auto arrayTy = mlir::dyn_cast<ArrayType>(f.getType())) {
167 auto subRes = getAllConstrainRefs(tables,
mod, arrayTy, blockArg, subFields);
168 res.insert(res.end(), subRes.begin(), subRes.end());
174std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
175 mlir::SymbolTableCollection &tables, mlir::ModuleOp
mod, mlir::BlockArgument arg,
176 std::vector<ConstrainRefIndex> fields
178 ConstrainRef root(arg, fields);
179 auto ty = root.getType();
180 std::vector<ConstrainRef> res;
181 if (
auto structTy = mlir::dyn_cast<StructType>(ty)) {
183 res = getAllConstrainRefs(tables,
mod,
getStructDef(tables,
mod, structTy), arg, fields);
184 }
else if (
auto arrayType = mlir::dyn_cast<ArrayType>(ty)) {
185 res = getAllConstrainRefs(tables,
mod, arrayType, arg, fields);
186 }
else if (mlir::isa<FeltType, IndexType, StringType>(ty)) {
188 res.emplace_back(root);
191 debug::Appender(err) <<
"unsupported type: " << ty;
192 llvm::report_fatal_error(mlir::Twine(err));
197std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
StructDefOp structDef) {
198 std::vector<ConstrainRef> res;
204 mlir::succeeded(modOp),
205 "could not lookup module from struct " + mlir::Twine(structDef.getName())
208 mlir::SymbolTableCollection tables;
209 for (
auto a : constrainFnOp.getArguments()) {
210 auto argRes = getAllConstrainRefs(tables, modOp.value(), a);
211 res.insert(res.end(), argRes.begin(), argRes.end());
216std::vector<ConstrainRef>
218 std::vector<ConstrainRef> res;
221 fieldDef->getParentOfType<
StructDefOp>() == structDef,
222 "Field " + mlir::Twine(fieldDef.getName()) +
" is not a field of struct " +
223 mlir::Twine(structDef.getName())
227 mlir::succeeded(modOp),
228 "could not lookup module from struct " + mlir::Twine(structDef.getName())
232 BlockArgument self = constrainFnOp.
getBody().getArgument(0);
234 mlir::SymbolTableCollection tables;
235 return getAllConstrainRefs(tables, modOp.value(), self, {ConstrainRefIndex(fieldDef)});
240 return std::get<FeltConstantOp>(*constantVal).getType();
242 return std::get<mlir::arith::ConstantIndexOp>(*constantVal).getType();
244 return std::get<ConstReadOp>(*constantVal).getType();
246 int array_derefs = 0;
247 int idx = fieldRefs.size() - 1;
248 while (idx >= 0 && fieldRefs[idx].
isIndex()) {
254 mlir::Type currTy = fieldRefs[idx].getField().getType();
255 while (array_derefs > 0) {
256 currTy = mlir::dyn_cast<ArrayType>(currTy).getElementType();
261 return blockArg.getType();
271 if (blockArg != prefix.blockArg || fieldRefs.size() < prefix.fieldRefs.size()) {
274 for (
size_t i = 0; i < prefix.fieldRefs.size(); i++) {
275 if (fieldRefs[i] != prefix.fieldRefs[i]) {
285 return mlir::failure();
287 std::vector<ConstrainRefIndex> suffix;
288 for (
size_t i = prefix.fieldRefs.size(); i < fieldRefs.size(); i++) {
289 suffix.push_back(fieldRefs[i]);
294mlir::FailureOr<ConstrainRef>
300 if (mlir::failed(suffix)) {
301 return mlir::failure();
304 auto newSignalUsage = other;
305 newSignalUsage.fieldRefs.insert(newSignalUsage.fieldRefs.end(), suffix->begin(), suffix->end());
306 return newSignalUsage;
315 auto constRead = std::get<ConstReadOp>(*constantVal);
316 auto structDefOp = constRead->getParentOfType<
StructDefOp>();
317 ensure(structDefOp,
"struct template should have a struct parent");
318 os <<
'@' << structDefOp.getName() <<
"<[@" << constRead.getConstName() <<
"]>";
322 for (
auto f : fieldRefs) {
323 os <<
"[" << f <<
"]";
329 return (blockArg == rhs.blockArg) && (fieldRefs == rhs.fieldRefs) &&
330 (constantVal == rhs.constantVal);
343 auto bitWidthMax = std::max(lhsInt.getBitWidth(), rhsInt.getBitWidth());
344 return lhsInt.zext(bitWidthMax).ult(rhsInt.zext(bitWidthMax));
362 auto lhsName = std::get<ConstReadOp>(*constantVal).getConstName();
363 auto rhsName = std::get<ConstReadOp>(*rhs.constantVal).getConstName();
364 return lhsName.compare(rhsName) < 0;
375 for (
size_t i = 0; i < fieldRefs.size() && i < rhs.fieldRefs.size(); i++) {
376 if (fieldRefs[i] < rhs.fieldRefs[i]) {
378 }
else if (fieldRefs[i] > rhs.fieldRefs[i]) {
382 return fieldRefs.size() < rhs.fieldRefs.size();
390 }(std::get<mlir::arith::ConstantIndexOp>(*val.constantVal));
396 size_t hash = std::hash<unsigned> {}(val.
getInputNum());
397 for (
auto f : val.fieldRefs) {
412 insert(rhs.begin(), rhs.end());
418 std::vector<ConstrainRef> sortedRefs(rhs.begin(), rhs.end());
419 std::sort(sortedRefs.begin(), sortedRefs.end());
420 for (
auto it = sortedRefs.begin(); it != sortedRefs.end();) {
423 if (it != sortedRefs.end()) {
This file defines methods for symbol lookup across LLZK operations and included files.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
bool operator<(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(component::FieldDefOp f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
Defines a reference to an LLZK object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool operator<(const ConstrainRef &rhs) const
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isTemplateConstant() const
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
ConstrainRef(mlir::BlockArgument b)
bool operator==(const ConstrainRef &rhs) const
bool isBlockArgument() const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
mlir::Type getType() const
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
unsigned getInputNum() const
static constexpr ::llvm::StringLiteral getOperationName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
::mlir::Region & getBody()
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
FailureOr< ModuleOp > getRootModule(Operation *from)
void ensure(bool condition, llvm::Twine errMsg)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
SymbolLookupResult< StructDefOp > getStructDef(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
size_t operator()(const ConstrainRefIndex &c) const
size_t operator()(const ConstrainRef &val) const