LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRef.cpp
Go to the documentation of this file.
1//===-- ConstrainRef.cpp - ConstrainRef implementation ----------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
14#include "llzk/Util/Compare.h"
15#include "llzk/Util/Debug.h"
18
19using namespace mlir;
20
21namespace llzk {
22
23using namespace array;
24using namespace component;
25using namespace felt;
26using namespace function;
27using namespace polymorphic;
28using namespace string;
29
30/* ConstrainRefIndex */
31
32void ConstrainRefIndex::print(mlir::raw_ostream &os) const {
33 if (isField()) {
34 os << '@' << getField().getName();
35 } else if (isIndex()) {
36 os << getIndex();
37 } else {
38 auto r = getIndexRange();
39 os << std::get<0>(r) << ':' << std::get<1>(r);
40 }
41}
42
// NOTE(review): the signature line and the body of the field-vs-field case are
// missing from this listing. Per the member index this is
// `bool ConstrainRefIndex::operator<(const ConstrainRefIndex &rhs) const`.
// How two fields are ordered cannot be seen here.
44 if (isField() && rhs.isField()) {
46 }
// Two single indices: unsigned comparison of the index values.
47 if (isIndex() && rhs.isIndex()) {
48 return getIndex().ult(rhs.getIndex());
49 }
// Two ranges: lexicographic unsigned comparison, first component then second.
50 if (isIndexRange() && rhs.isIndexRange()) {
51 auto l = getIndexRange(), r = rhs.getIndexRange();
52 auto ll = std::get<0>(l), lu = std::get<1>(l);
53 auto rl = std::get<0>(r), ru = std::get<1>(r);
54 return ll.ult(rl) || (ll == rl && lu.ult(ru));
55 }
56
// Mixed kinds from here on (same-kind pairs returned above, assuming the
// missing field-vs-field branch returns). The resulting order is:
// fields < single indices < index ranges.
57 if (isField()) {
58 return true;
59 }
// `this` is an index and `rhs` is not a field, so `rhs` must be a range.
60 if (isIndex() && !rhs.isField()) {
61 return true;
62 }
63
64 return false;
65}
66
// NOTE(review): the signature line is missing from this listing. Per the member
// index this is `size_t operator()(const ConstrainRefIndex &c) const`.
// Single index: hash the index value directly.
68 if (c.isIndex()) {
69 return llvm::hash_value(c.getIndex());
70 } else if (c.isIndexRange()) {
// Range: XOR of the hashes of both components. XOR is symmetric, so a range
// and its component-swapped counterpart hash identically.
71 auto r = c.getIndexRange();
72 return llvm::hash_value(std::get<0>(r)) ^ llvm::hash_value(std::get<1>(r));
73 } else {
// NOTE(review): the body of this branch is missing from the listing;
// presumably it hashes the field definition -- confirm against the real file.
75 }
76}
77
78/* ConstrainRef */
79
88getStructDef(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, StructType ty) {
// Looks up the StructDefOp that defines `ty`, using `tables` and `mod` for the
// lookup, and returns the lookup result.
// NOTE(review): the return-type line is missing from this listing; the member
// index gives it as `SymbolLookupResult<StructDefOp>`.
89 auto sDef = ty.getDefinition(tables, mod);
// A failed lookup is reported through `ensure` with the message below.
// NOTE(review): presumably `ensure` does not return on failure -- confirm in
// ErrorHelper.h; otherwise `sDef.value()` below would be read on a failure.
90 ensure(
91 mlir::succeeded(sDef),
92 "could not find '" + StructDefOp::getOperationName() + "' op from struct type"
93 );
94
// `sDef` is a local wrapper, so move the contained result out of it.
95 return std::move(sDef.value());
96}
97
98std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
99 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ArrayType arrayTy,
100 mlir::BlockArgument blockArg, std::vector<ConstrainRefIndex> fields = {}
101) {
102 std::vector<ConstrainRef> res;
103 // Add root item
104 res.emplace_back(blockArg, fields);
105
106 // Recurse into arrays by iterating over their elements
107 int64_t maxSz = arrayTy.getDimSize(0);
108 for (int64_t i = 0; i < maxSz; i++) {
109 auto elemTy = arrayTy.getElementType();
110
111 std::vector<ConstrainRefIndex> subFields = fields;
112 subFields.emplace_back(i);
113
114 if (auto arrayElemTy = mlir::dyn_cast<ArrayType>(elemTy)) {
115 // recurse
116 auto subRes = getAllConstrainRefs(tables, mod, arrayElemTy, blockArg, subFields);
117 res.insert(res.end(), subRes.begin(), subRes.end());
118 } else if (auto structTy = mlir::dyn_cast<StructType>(elemTy)) {
119 // recurse into struct def
120 auto subRes = getAllConstrainRefs(
121 tables, mod, getStructDef(tables, mod, structTy), blockArg, subFields
122 );
123 res.insert(res.end(), subRes.begin(), subRes.end());
124 } else {
125 // scalar type
126 res.emplace_back(blockArg, subFields);
127 }
128 }
129
130 return res;
131}
132
133std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
134 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod,
135 SymbolLookupResult<StructDefOp> structDefRes, mlir::BlockArgument blockArg,
136 std::vector<ConstrainRefIndex> fields = {}
137) {
138 std::vector<ConstrainRef> res;
139 // Add root item
140 res.emplace_back(blockArg, fields);
141 // Recurse into struct types by iterating over all their field definitions
142 for (auto f : structDefRes.get().getOps<FieldDefOp>()) {
143 std::vector<ConstrainRefIndex> subFields = fields;
144 // We want to store the FieldDefOp, but without the possibility of accidentally dropping the
145 // reference, so we need to re-lookup the symbol to create a SymbolLookupResult, which will
146 // manage the external module containing the field defs, if needed.
147 // TODO: It would be nice if we could manage module op references differently
148 // so we don't have to do this.
149 auto structDefCopy = structDefRes;
150 auto fieldLookup = lookupSymbolIn<FieldDefOp>(
151 tables, mlir::SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()),
152 std::move(structDefCopy), mod.getOperation()
153 );
154 ensure(mlir::succeeded(fieldLookup), "could not get SymbolLookupResult of existing FieldDefOp");
155 subFields.emplace_back(fieldLookup.value());
156 // Make a reference to the current field, regardless of if it is a composite
157 // type or not.
158 res.emplace_back(blockArg, subFields);
159 if (auto structTy = mlir::dyn_cast<StructType>(f.getType())) {
160 // Create refs for each field
161 auto subRes = getAllConstrainRefs(
162 tables, mod, getStructDef(tables, mod, structTy), blockArg, subFields
163 );
164 res.insert(res.end(), subRes.begin(), subRes.end());
165 } else if (auto arrayTy = mlir::dyn_cast<ArrayType>(f.getType())) {
166 // Create refs for each array element
167 auto subRes = getAllConstrainRefs(tables, mod, arrayTy, blockArg, subFields);
168 res.insert(res.end(), subRes.begin(), subRes.end());
169 }
170 }
171 return res;
172}
173
174std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
175 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, mlir::BlockArgument arg,
176 std::vector<ConstrainRefIndex> fields
177) {
178 ConstrainRef root(arg, fields);
179 auto ty = root.getType();
180 std::vector<ConstrainRef> res;
181 if (auto structTy = mlir::dyn_cast<StructType>(ty)) {
182 // recurse over fields
183 res = getAllConstrainRefs(tables, mod, getStructDef(tables, mod, structTy), arg, fields);
184 } else if (auto arrayType = mlir::dyn_cast<ArrayType>(ty)) {
185 res = getAllConstrainRefs(tables, mod, arrayType, arg, fields);
186 } else if (mlir::isa<FeltType, IndexType, StringType>(ty)) {
187 // Scalar type
188 res.emplace_back(root);
189 } else {
190 std::string err;
191 debug::Appender(err) << "unsupported type: " << ty;
192 llvm::report_fatal_error(mlir::Twine(err));
193 }
194 return res;
195}
196
197std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(StructDefOp structDef) {
198 std::vector<ConstrainRef> res;
199 // Must have a constrain function by definition.
200 FuncDefOp constrainFnOp = structDef.getConstrainFuncOp();
201
202 FailureOr<ModuleOp> modOp = getRootModule(structDef);
203 ensure(
204 mlir::succeeded(modOp),
205 "could not lookup module from struct " + mlir::Twine(structDef.getName())
206 );
207
208 mlir::SymbolTableCollection tables;
209 for (auto a : constrainFnOp.getArguments()) {
210 auto argRes = getAllConstrainRefs(tables, modOp.value(), a);
211 res.insert(res.end(), argRes.begin(), argRes.end());
212 }
213 return res;
214}
215
216std::vector<ConstrainRef>
217ConstrainRef::getAllConstrainRefs(StructDefOp structDef, FieldDefOp fieldDef) {
218 std::vector<ConstrainRef> res;
219 FuncDefOp constrainFnOp = structDef.getConstrainFuncOp();
220 ensure(
221 fieldDef->getParentOfType<StructDefOp>() == structDef,
222 "Field " + mlir::Twine(fieldDef.getName()) + " is not a field of struct " +
223 mlir::Twine(structDef.getName())
224 );
225 FailureOr<ModuleOp> modOp = getRootModule(structDef);
226 ensure(
227 mlir::succeeded(modOp),
228 "could not lookup module from struct " + mlir::Twine(structDef.getName())
229 );
230
231 // Get the self argument
232 BlockArgument self = constrainFnOp.getBody().getArgument(0);
233
234 mlir::SymbolTableCollection tables;
235 return getAllConstrainRefs(tables, modOp.value(), self, {ConstrainRefIndex(fieldDef)});
236}
237
238mlir::Type ConstrainRef::getType() const {
239 if (isConstantFelt()) {
240 return std::get<FeltConstantOp>(*constantVal).getType();
241 } else if (isConstantIndex()) {
242 return std::get<mlir::arith::ConstantIndexOp>(*constantVal).getType();
243 } else if (isTemplateConstant()) {
244 return std::get<ConstReadOp>(*constantVal).getType();
245 } else {
246 int array_derefs = 0;
247 int idx = fieldRefs.size() - 1;
248 while (idx >= 0 && fieldRefs[idx].isIndex()) {
249 array_derefs++;
250 idx--;
251 }
252
253 if (idx >= 0) {
254 mlir::Type currTy = fieldRefs[idx].getField().getType();
255 while (array_derefs > 0) {
256 currTy = mlir::dyn_cast<ArrayType>(currTy).getElementType();
257 array_derefs--;
258 }
259 return currTy;
260 } else {
261 return blockArg.getType();
262 }
263 }
264}
265
266bool ConstrainRef::isValidPrefix(const ConstrainRef &prefix) const {
267 if (isConstant()) {
268 return false;
269 }
270
271 if (blockArg != prefix.blockArg || fieldRefs.size() < prefix.fieldRefs.size()) {
272 return false;
273 }
274 for (size_t i = 0; i < prefix.fieldRefs.size(); i++) {
275 if (fieldRefs[i] != prefix.fieldRefs[i]) {
276 return false;
277 }
278 }
279 return true;
280}
281
282mlir::FailureOr<std::vector<ConstrainRefIndex>> ConstrainRef::getSuffix(const ConstrainRef &prefix
283) const {
284 if (!isValidPrefix(prefix)) {
285 return mlir::failure();
286 }
287 std::vector<ConstrainRefIndex> suffix;
288 for (size_t i = prefix.fieldRefs.size(); i < fieldRefs.size(); i++) {
289 suffix.push_back(fieldRefs[i]);
290 }
291 return suffix;
292}
293
294mlir::FailureOr<ConstrainRef>
295ConstrainRef::translate(const ConstrainRef &prefix, const ConstrainRef &other) const {
296 if (isConstant()) {
297 return *this;
298 }
299 auto suffix = getSuffix(prefix);
300 if (mlir::failed(suffix)) {
301 return mlir::failure();
302 }
303
304 auto newSignalUsage = other;
305 newSignalUsage.fieldRefs.insert(newSignalUsage.fieldRefs.end(), suffix->begin(), suffix->end());
306 return newSignalUsage;
307}
308
309void ConstrainRef::print(mlir::raw_ostream &os) const {
310 if (isConstantFelt()) {
311 os << "<felt.const: " << getConstantFeltValue() << '>';
312 } else if (isConstantIndex()) {
313 os << "<index: " << getConstantIndexValue() << '>';
314 } else if (isTemplateConstant()) {
315 auto constRead = std::get<ConstReadOp>(*constantVal);
316 auto structDefOp = constRead->getParentOfType<StructDefOp>();
317 ensure(structDefOp, "struct template should have a struct parent");
318 os << '@' << structDefOp.getName() << "<[@" << constRead.getConstName() << "]>";
319 } else {
320 ensure(isBlockArgument(), "unhandled print case");
321 os << "%arg" << getInputNum();
322 for (auto f : fieldRefs) {
323 os << "[" << f << "]";
324 }
325 }
326}
327
// NOTE(review): the signature line is missing from this listing. Per the member
// index this is `bool ConstrainRef::operator==(const ConstrainRef &rhs) const`.
// Two references are equal iff the block argument, the field reference list
// and the constant value all compare equal.
329 return (blockArg == rhs.blockArg) && (fieldRefs == rhs.fieldRefs) &&
330 (constantVal == rhs.constantVal);
331}
332
333// required for EquivalenceClasses usage
// Overall order produced by the checks below:
//   block-argument refs < template constants < constant indices < constant felts
// Each group is tested in turn, starting with the one that sorts last.
334bool ConstrainRef::operator<(const ConstrainRef &rhs) const {
335 if (isConstantFelt() && !rhs.isConstantFelt()) {
336 // Put all constants at the end
337 return false;
338 } else if (!isConstantFelt() && rhs.isConstantFelt()) {
339 return true;
340 } else if (isConstantFelt() && rhs.isConstantFelt()) {
// Two felt constants: zero-extend both values to the wider of the two bit
// widths, then compare them as unsigned integers.
341 auto lhsInt = getConstantFeltValue();
342 auto rhsInt = rhs.getConstantFeltValue();
343 auto bitWidthMax = std::max(lhsInt.getBitWidth(), rhsInt.getBitWidth());
344 return lhsInt.zext(bitWidthMax).ult(rhsInt.zext(bitWidthMax));
345 }
346
347 if (isConstantIndex() && !rhs.isConstantIndex()) {
348 // Put all constant indices next at the end
349 return false;
350 } else if (!isConstantIndex() && rhs.isConstantIndex()) {
351 return true;
352 } else if (isConstantIndex() && rhs.isConstantIndex()) {
// NOTE(review): the body of this branch (listing line 353) is missing here;
// presumably it compares the two constant index values -- confirm against the
// real file.
354 }
355
356 if (isTemplateConstant() && !rhs.isTemplateConstant()) {
357 // Put all template constants next at the end
358 return false;
359 } else if (!isTemplateConstant() && rhs.isTemplateConstant()) {
360 return true;
361 } else if (isTemplateConstant() && rhs.isTemplateConstant()) {
// Two template constants: ordered by constant name only.
362 auto lhsName = std::get<ConstReadOp>(*constantVal).getConstName();
363 auto rhsName = std::get<ConstReadOp>(*rhs.constantVal).getConstName();
364 return lhsName.compare(rhsName) < 0;
365 }
366
367 // both are not constants
368 ensure(isBlockArgument() && rhs.isBlockArgument(), "unhandled operator< case");
// Block-argument refs: order by argument number first.
369 if (getInputNum() < rhs.getInputNum()) {
370 return true;
371 } else if (getInputNum() > rhs.getInputNum()) {
372 return false;
373 }
374
// Same argument: lexicographic comparison of the field reference lists.
375 for (size_t i = 0; i < fieldRefs.size() && i < rhs.fieldRefs.size(); i++) {
376 if (fieldRefs[i] < rhs.fieldRefs[i]) {
377 return true;
378 } else if (fieldRefs[i] > rhs.fieldRefs[i]) {
379 return false;
380 }
381 }
// All shared entries compare equivalent: the shorter list orders first.
382 return fieldRefs.size() < rhs.fieldRefs.size();
383}
384
// NOTE(review): the signature line is missing from this listing. Per the member
// index this is `size_t operator()(const ConstrainRef &val) const`.
// Constant references hash the op that defines them via OpHash.
386 if (val.isConstantFelt()) {
387 return OpHash<FeltConstantOp> {}(std::get<FeltConstantOp>(*val.constantVal));
388 } else if (val.isConstantIndex()) {
// NOTE(review): the first line of this return statement (listing line 389) is
// missing here; presumably it is `return OpHash<mlir::arith::ConstantIndexOp> {`
// by analogy with the neighbouring branches -- confirm against the real file.
390 }(std::get<mlir::arith::ConstantIndexOp>(*val.constantVal));
391 } else if (val.isTemplateConstant()) {
392 return OpHash<ConstReadOp> {}(std::get<ConstReadOp>(*val.constantVal));
393 } else {
394 ensure(val.isBlockArgument(), "unhandled operator() case");
395
// Block-argument refs: start from the hash of the argument number and XOR in
// the hash of every field reference. XOR is commutative, so two refs holding
// the same entries in a different order produce the same hash.
396 size_t hash = std::hash<unsigned> {}(val.getInputNum());
397 for (auto f : val.fieldRefs) {
398 hash ^= f.getHash();
399 }
400 return hash;
401 }
402}
403
404mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs) {
405 rhs.print(os);
406 return os;
407}
408
409/* ConstrainRefSet */
410
// NOTE(review): the signature line is missing from this listing. Per the member
// index this is `ConstrainRefSet &ConstrainRefSet::join(const ConstrainRefSet &rhs)`.
// Inserts every element of `rhs` into this set and returns this set.
412 insert(rhs.begin(), rhs.end());
413 return *this;
414}
415
416mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs) {
417 os << "{ ";
418 std::vector<ConstrainRef> sortedRefs(rhs.begin(), rhs.end());
419 std::sort(sortedRefs.begin(), sortedRefs.end());
420 for (auto it = sortedRefs.begin(); it != sortedRefs.end();) {
421 os << *it;
422 it++;
423 if (it != sortedRefs.end()) {
424 os << ", ";
425 } else {
426 os << ' ';
427 }
428 }
429 os << '}';
430 return os;
431}
432
433} // namespace llzk
This file defines methods for symbol lookup across LLZK operations and included files.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
bool operator<(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(component::FieldDefOp f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool operator<(const ConstrainRef &rhs) const
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isConstant() const
bool isTemplateConstant() const
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
ConstrainRef(mlir::BlockArgument b)
bool operator==(const ConstrainRef &rhs) const
bool isBlockArgument() const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this referenc...
mlir::Type getType() const
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the pref...
unsigned getInputNum() const
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:919
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
FailureOr< ModuleOp > getRootModule(Operation *from)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
SymbolLookupResult< StructDefOp > getStructDef(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
size_t operator()(const ConstrainRefIndex &c) const
size_t operator()(const ConstrainRef &val) const