38 os << std::get<0>(r) <<
':' << std::get<1>(r);
51 auto ll = std::get<0>(l), lu = std::get<1>(l);
52 auto rl = std::get<0>(r), ru = std::get<1>(r);
53 return ll.ult(rl) || (ll == rl && lu.ult(ru));
79 mlir::succeeded(sDef),
83 return std::move(sDef.value());
86std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
87 mlir::SymbolTableCollection &tables, mlir::ModuleOp
mod, ArrayType arrayTy,
88 mlir::BlockArgument blockArg, std::vector<ConstrainRefIndex> fields = {}
90 std::vector<ConstrainRef> res;
92 res.emplace_back(blockArg, fields);
95 int64_t maxSz = arrayTy.getDimSize(0);
96 for (int64_t i = 0; i < maxSz; i++) {
97 auto elemTy = arrayTy.getElementType();
99 std::vector<ConstrainRefIndex> subFields = fields;
100 subFields.emplace_back(i);
102 if (
auto arrayElemTy = mlir::dyn_cast<ArrayType>(elemTy)) {
104 auto subRes = getAllConstrainRefs(tables,
mod, arrayElemTy, blockArg, subFields);
105 res.insert(res.end(), subRes.begin(), subRes.end());
106 }
else if (
auto structTy = mlir::dyn_cast<StructType>(elemTy)) {
108 auto subRes = getAllConstrainRefs(
111 res.insert(res.end(), subRes.begin(), subRes.end());
114 res.emplace_back(blockArg, subFields);
121std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
122 mlir::SymbolTableCollection &tables, mlir::ModuleOp
mod,
124 std::vector<ConstrainRefIndex> fields = {}
126 std::vector<ConstrainRef> res;
128 res.emplace_back(blockArg, fields);
130 for (
auto f : structDefRes.get().getOps<
FieldDefOp>()) {
131 std::vector<ConstrainRefIndex> subFields = fields;
137 auto structDefCopy = structDefRes;
139 tables, mlir::SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()),
140 std::move(structDefCopy),
mod.getOperation()
142 ensure(mlir::succeeded(fieldLookup),
"could not get SymbolLookupResult of existing FieldDefOp");
143 subFields.emplace_back(fieldLookup.value());
146 res.emplace_back(blockArg, subFields);
147 if (
auto structTy = mlir::dyn_cast<StructType>(f.getType())) {
149 auto subRes = getAllConstrainRefs(
152 res.insert(res.end(), subRes.begin(), subRes.end());
153 }
else if (
auto arrayTy = mlir::dyn_cast<ArrayType>(f.getType())) {
155 auto subRes = getAllConstrainRefs(tables,
mod, arrayTy, blockArg, subFields);
156 res.insert(res.end(), subRes.begin(), subRes.end());
162std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
163 mlir::SymbolTableCollection &tables, mlir::ModuleOp
mod, mlir::BlockArgument arg
165 auto ty = arg.getType();
166 std::vector<ConstrainRef> res;
167 if (
auto structTy = mlir::dyn_cast<StructType>(ty)) {
170 }
else if (
auto arrayType = mlir::dyn_cast<ArrayType>(ty)) {
171 res = getAllConstrainRefs(tables,
mod, arrayType, arg);
172 }
else if (mlir::isa<FeltType, IndexType, StringType>(ty)) {
174 res.emplace_back(arg);
177 debug::Appender(err) <<
"unsupported type: " << ty;
178 llvm::report_fatal_error(mlir::Twine(err));
183std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
StructDefOp structDef) {
184 std::vector<ConstrainRef> res;
188 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
193 mlir::succeeded(modOp),
194 "could not lookup module from struct " + mlir::Twine(structDef.getName())
197 mlir::SymbolTableCollection tables;
198 for (
auto a : constrainFnOp.getArguments()) {
199 auto argRes = getAllConstrainRefs(tables, modOp.value(), a);
200 res.insert(res.end(), argRes.begin(), argRes.end());
207 return std::get<FeltConstantOp>(*constantVal).getType();
209 return std::get<mlir::arith::ConstantIndexOp>(*constantVal).getType();
211 return std::get<ConstReadOp>(*constantVal).getType();
213 int array_derefs = 0;
214 int idx = fieldRefs.size() - 1;
215 while (idx >= 0 && fieldRefs[idx].isIndex()) {
221 mlir::Type currTy = fieldRefs[idx].getField().getType();
222 while (array_derefs > 0) {
223 currTy = mlir::dyn_cast<ArrayType>(currTy).getElementType();
228 return blockArg.getType();
238 if (blockArg != prefix.blockArg || fieldRefs.size() < prefix.fieldRefs.size()) {
241 for (
size_t i = 0; i < prefix.fieldRefs.size(); i++) {
242 if (fieldRefs[i] != prefix.fieldRefs[i]) {
252 return mlir::failure();
254 std::vector<ConstrainRefIndex> suffix;
255 for (
size_t i = prefix.fieldRefs.size(); i < fieldRefs.size(); i++) {
256 suffix.push_back(fieldRefs[i]);
261mlir::FailureOr<ConstrainRef>
267 if (mlir::failed(suffix)) {
268 return mlir::failure();
271 auto newSignalUsage = other;
272 newSignalUsage.fieldRefs.insert(newSignalUsage.fieldRefs.end(), suffix->begin(), suffix->end());
273 return newSignalUsage;
282 auto constRead = std::get<ConstReadOp>(*constantVal);
283 auto structDefOp = constRead->getParentOfType<
StructDefOp>();
284 ensure(structDefOp,
"struct template should have a struct parent");
285 os <<
'@' << structDefOp.getName() <<
"<[@" << constRead.getConstName() <<
"]>";
289 for (
auto f : fieldRefs) {
290 os <<
"[" << f <<
"]";
296 return (blockArg == rhs.blockArg) && (fieldRefs == rhs.fieldRefs) &&
297 (constantVal == rhs.constantVal);
310 auto bitWidthMax = std::max(lhsInt.getBitWidth(), rhsInt.getBitWidth());
311 return lhsInt.zext(bitWidthMax).ult(rhsInt.zext(bitWidthMax));
329 auto lhsName = std::get<ConstReadOp>(*constantVal).getConstName();
330 auto rhsName = std::get<ConstReadOp>(*rhs.constantVal).getConstName();
331 return lhsName.compare(rhsName) < 0;
342 for (
size_t i = 0; i < fieldRefs.size() && i < rhs.fieldRefs.size(); i++) {
343 if (fieldRefs[i] < rhs.fieldRefs[i]) {
345 }
else if (fieldRefs[i] > rhs.fieldRefs[i]) {
349 return fieldRefs.size() < rhs.fieldRefs.size();
357 }(std::get<mlir::arith::ConstantIndexOp>(*val.constantVal));
363 size_t hash = std::hash<unsigned> {}(val.
getInputNum());
364 for (
auto f : val.fieldRefs) {
379 insert(rhs.begin(), rhs.end());
385 std::vector<ConstrainRef> sortedRefs(rhs.begin(), rhs.end());
386 std::sort(sortedRefs.begin(), sortedRefs.end());
387 for (
auto it = sortedRefs.begin(); it != sortedRefs.end();) {
390 if (it != sortedRefs.end()) {
This file defines methods symbol lookup across LLZK operations and included files.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
bool operator<(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(SymbolLookupResult< component::FieldDefOp > f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool operator<(const ConstrainRef &rhs) const
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isTemplateConstant() const
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
ConstrainRef(mlir::BlockArgument b)
bool operator==(const ConstrainRef &rhs) const
bool isBlockArgument() const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this referenc...
mlir::Type getType() const
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the pref...
unsigned getInputNum() const
static constexpr ::llvm::StringLiteral getOperationName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
FailureOr< ModuleOp > getRootModule(Operation *from)
void ensure(bool condition, llvm::Twine errMsg)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
SymbolLookupResult< StructDefOp > getStructDef(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
size_t operator()(const ConstrainRef &val) const