LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRef.cpp
Go to the documentation of this file.
1//===-- ConstrainRef.cpp - ConstrainRef implementation ----------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
14#include "llzk/Util/Compare.h"
15#include "llzk/Util/Debug.h"
18
19using namespace mlir;
20
21namespace llzk {
22
23using namespace array;
24using namespace component;
25using namespace felt;
26using namespace polymorphic;
27using namespace string;
28
29/* ConstrainRefIndex */
30
31void ConstrainRefIndex::print(mlir::raw_ostream &os) const {
32 if (isField()) {
33 os << '@' << getField().getName();
34 } else if (isIndex()) {
35 os << getIndex();
36 } else {
37 auto r = getIndexRange();
38 os << std::get<0>(r) << ':' << std::get<1>(r);
39 }
40}
41
43 if (isField() && rhs.isField()) {
45 }
46 if (isIndex() && rhs.isIndex()) {
47 return getIndex().ult(rhs.getIndex());
48 }
49 if (isIndexRange() && rhs.isIndexRange()) {
50 auto l = getIndexRange(), r = rhs.getIndexRange();
51 auto ll = std::get<0>(l), lu = std::get<1>(l);
52 auto rl = std::get<0>(r), ru = std::get<1>(r);
53 return ll.ult(rl) || (ll == rl && lu.ult(ru));
54 }
55
56 if (isField()) {
57 return true;
58 }
59 if (isIndex() && !rhs.isField()) {
60 return true;
61 }
62
63 return false;
64}
65
66/* ConstrainRef */
67
76getStructDef(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, StructType ty) {
77 auto sDef = ty.getDefinition(tables, mod);
78 ensure(
79 mlir::succeeded(sDef),
80 "could not find '" + StructDefOp::getOperationName() + "' op from struct type"
81 );
82
83 return std::move(sDef.value());
84}
85
86std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
87 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ArrayType arrayTy,
88 mlir::BlockArgument blockArg, std::vector<ConstrainRefIndex> fields = {}
89) {
90 std::vector<ConstrainRef> res;
91 // Add root item
92 res.emplace_back(blockArg, fields);
93
94 // Recurse into arrays by iterating over their elements
95 int64_t maxSz = arrayTy.getDimSize(0);
96 for (int64_t i = 0; i < maxSz; i++) {
97 auto elemTy = arrayTy.getElementType();
98
99 std::vector<ConstrainRefIndex> subFields = fields;
100 subFields.emplace_back(i);
101
102 if (auto arrayElemTy = mlir::dyn_cast<ArrayType>(elemTy)) {
103 // recurse
104 auto subRes = getAllConstrainRefs(tables, mod, arrayElemTy, blockArg, subFields);
105 res.insert(res.end(), subRes.begin(), subRes.end());
106 } else if (auto structTy = mlir::dyn_cast<StructType>(elemTy)) {
107 // recurse into struct def
108 auto subRes = getAllConstrainRefs(
109 tables, mod, getStructDef(tables, mod, structTy), blockArg, subFields
110 );
111 res.insert(res.end(), subRes.begin(), subRes.end());
112 } else {
113 // scalar type
114 res.emplace_back(blockArg, subFields);
115 }
116 }
117
118 return res;
119}
120
121std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
122 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod,
123 SymbolLookupResult<StructDefOp> structDefRes, mlir::BlockArgument blockArg,
124 std::vector<ConstrainRefIndex> fields = {}
125) {
126 std::vector<ConstrainRef> res;
127 // Add root item
128 res.emplace_back(blockArg, fields);
129 // Recurse into struct types by iterating over all their field definitions
130 for (auto f : structDefRes.get().getOps<FieldDefOp>()) {
131 std::vector<ConstrainRefIndex> subFields = fields;
132 // We want to store the FieldDefOp, but without the possibility of accidentally dropping the
133 // reference, so we need to re-lookup the symbol to create a SymbolLookupResult, which will
134 // manage the external module containing the field defs, if needed.
135 // TODO: It would be nice if we could manage module op references differently
136 // so we don't have to do this.
137 auto structDefCopy = structDefRes;
138 auto fieldLookup = lookupSymbolIn<FieldDefOp>(
139 tables, mlir::SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()),
140 std::move(structDefCopy), mod.getOperation()
141 );
142 ensure(mlir::succeeded(fieldLookup), "could not get SymbolLookupResult of existing FieldDefOp");
143 subFields.emplace_back(fieldLookup.value());
144 // Make a reference to the current field, regardless of if it is a composite
145 // type or not.
146 res.emplace_back(blockArg, subFields);
147 if (auto structTy = mlir::dyn_cast<StructType>(f.getType())) {
148 // Create refs for each field
149 auto subRes = getAllConstrainRefs(
150 tables, mod, getStructDef(tables, mod, structTy), blockArg, subFields
151 );
152 res.insert(res.end(), subRes.begin(), subRes.end());
153 } else if (auto arrayTy = mlir::dyn_cast<ArrayType>(f.getType())) {
154 // Create refs for each array element
155 auto subRes = getAllConstrainRefs(tables, mod, arrayTy, blockArg, subFields);
156 res.insert(res.end(), subRes.begin(), subRes.end());
157 }
158 }
159 return res;
160}
161
162std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(
163 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, mlir::BlockArgument arg
164) {
165 auto ty = arg.getType();
166 std::vector<ConstrainRef> res;
167 if (auto structTy = mlir::dyn_cast<StructType>(ty)) {
168 // recurse over fields
169 res = getAllConstrainRefs(tables, mod, getStructDef(tables, mod, structTy), arg);
170 } else if (auto arrayType = mlir::dyn_cast<ArrayType>(ty)) {
171 res = getAllConstrainRefs(tables, mod, arrayType, arg);
172 } else if (mlir::isa<FeltType, IndexType, StringType>(ty)) {
173 // Scalar type
174 res.emplace_back(arg);
175 } else {
176 std::string err;
177 debug::Appender(err) << "unsupported type: " << ty;
178 llvm::report_fatal_error(mlir::Twine(err));
179 }
180 return res;
181}
182
183std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(StructDefOp structDef) {
184 std::vector<ConstrainRef> res;
185 auto constrainFnOp = structDef.getConstrainFuncOp();
186 ensure(
187 constrainFnOp,
188 "malformed struct " + mlir::Twine(structDef.getName()) + " must define a constrain function"
189 );
190
191 auto modOp = getRootModule(structDef);
192 ensure(
193 mlir::succeeded(modOp),
194 "could not lookup module from struct " + mlir::Twine(structDef.getName())
195 );
196
197 mlir::SymbolTableCollection tables;
198 for (auto a : constrainFnOp.getArguments()) {
199 auto argRes = getAllConstrainRefs(tables, modOp.value(), a);
200 res.insert(res.end(), argRes.begin(), argRes.end());
201 }
202 return res;
203}
204
205mlir::Type ConstrainRef::getType() const {
206 if (isConstantFelt()) {
207 return std::get<FeltConstantOp>(*constantVal).getType();
208 } else if (isConstantIndex()) {
209 return std::get<mlir::arith::ConstantIndexOp>(*constantVal).getType();
210 } else if (isTemplateConstant()) {
211 return std::get<ConstReadOp>(*constantVal).getType();
212 } else {
213 int array_derefs = 0;
214 int idx = fieldRefs.size() - 1;
215 while (idx >= 0 && fieldRefs[idx].isIndex()) {
216 array_derefs++;
217 idx--;
218 }
219
220 if (idx >= 0) {
221 mlir::Type currTy = fieldRefs[idx].getField().getType();
222 while (array_derefs > 0) {
223 currTy = mlir::dyn_cast<ArrayType>(currTy).getElementType();
224 array_derefs--;
225 }
226 return currTy;
227 } else {
228 return blockArg.getType();
229 }
230 }
231}
232
233bool ConstrainRef::isValidPrefix(const ConstrainRef &prefix) const {
234 if (isConstant()) {
235 return false;
236 }
237
238 if (blockArg != prefix.blockArg || fieldRefs.size() < prefix.fieldRefs.size()) {
239 return false;
240 }
241 for (size_t i = 0; i < prefix.fieldRefs.size(); i++) {
242 if (fieldRefs[i] != prefix.fieldRefs[i]) {
243 return false;
244 }
245 }
246 return true;
247}
248
249mlir::FailureOr<std::vector<ConstrainRefIndex>> ConstrainRef::getSuffix(const ConstrainRef &prefix
250) const {
251 if (!isValidPrefix(prefix)) {
252 return mlir::failure();
253 }
254 std::vector<ConstrainRefIndex> suffix;
255 for (size_t i = prefix.fieldRefs.size(); i < fieldRefs.size(); i++) {
256 suffix.push_back(fieldRefs[i]);
257 }
258 return suffix;
259}
260
261mlir::FailureOr<ConstrainRef>
262ConstrainRef::translate(const ConstrainRef &prefix, const ConstrainRef &other) const {
263 if (isConstant()) {
264 return *this;
265 }
266 auto suffix = getSuffix(prefix);
267 if (mlir::failed(suffix)) {
268 return mlir::failure();
269 }
270
271 auto newSignalUsage = other;
272 newSignalUsage.fieldRefs.insert(newSignalUsage.fieldRefs.end(), suffix->begin(), suffix->end());
273 return newSignalUsage;
274}
275
276void ConstrainRef::print(mlir::raw_ostream &os) const {
277 if (isConstantFelt()) {
278 os << "<felt.const: " << getConstantFeltValue() << '>';
279 } else if (isConstantIndex()) {
280 os << "<index: " << getConstantIndexValue() << '>';
281 } else if (isTemplateConstant()) {
282 auto constRead = std::get<ConstReadOp>(*constantVal);
283 auto structDefOp = constRead->getParentOfType<StructDefOp>();
284 ensure(structDefOp, "struct template should have a struct parent");
285 os << '@' << structDefOp.getName() << "<[@" << constRead.getConstName() << "]>";
286 } else {
287 ensure(isBlockArgument(), "unhandled print case");
288 os << "%arg" << getInputNum();
289 for (auto f : fieldRefs) {
290 os << "[" << f << "]";
291 }
292 }
293}
294
296 return (blockArg == rhs.blockArg) && (fieldRefs == rhs.fieldRefs) &&
297 (constantVal == rhs.constantVal);
298}
299
300// required for EquivalenceClasses usage
301bool ConstrainRef::operator<(const ConstrainRef &rhs) const {
302 if (isConstantFelt() && !rhs.isConstantFelt()) {
303 // Put all constants at the end
304 return false;
305 } else if (!isConstantFelt() && rhs.isConstantFelt()) {
306 return true;
307 } else if (isConstantFelt() && rhs.isConstantFelt()) {
308 auto lhsInt = getConstantFeltValue();
309 auto rhsInt = rhs.getConstantFeltValue();
310 auto bitWidthMax = std::max(lhsInt.getBitWidth(), rhsInt.getBitWidth());
311 return lhsInt.zext(bitWidthMax).ult(rhsInt.zext(bitWidthMax));
312 }
313
314 if (isConstantIndex() && !rhs.isConstantIndex()) {
315 // Put all constant indices next at the end
316 return false;
317 } else if (!isConstantIndex() && rhs.isConstantIndex()) {
318 return true;
319 } else if (isConstantIndex() && rhs.isConstantIndex()) {
321 }
322
323 if (isTemplateConstant() && !rhs.isTemplateConstant()) {
324 // Put all template constants next at the end
325 return false;
326 } else if (!isTemplateConstant() && rhs.isTemplateConstant()) {
327 return true;
328 } else if (isTemplateConstant() && rhs.isTemplateConstant()) {
329 auto lhsName = std::get<ConstReadOp>(*constantVal).getConstName();
330 auto rhsName = std::get<ConstReadOp>(*rhs.constantVal).getConstName();
331 return lhsName.compare(rhsName) < 0;
332 }
333
334 // both are not constants
335 ensure(isBlockArgument() && rhs.isBlockArgument(), "unhandled operator< case");
336 if (getInputNum() < rhs.getInputNum()) {
337 return true;
338 } else if (getInputNum() > rhs.getInputNum()) {
339 return false;
340 }
341
342 for (size_t i = 0; i < fieldRefs.size() && i < rhs.fieldRefs.size(); i++) {
343 if (fieldRefs[i] < rhs.fieldRefs[i]) {
344 return true;
345 } else if (fieldRefs[i] > rhs.fieldRefs[i]) {
346 return false;
347 }
348 }
349 return fieldRefs.size() < rhs.fieldRefs.size();
350}
351
353 if (val.isConstantFelt()) {
354 return OpHash<FeltConstantOp> {}(std::get<FeltConstantOp>(*val.constantVal));
355 } else if (val.isConstantIndex()) {
357 }(std::get<mlir::arith::ConstantIndexOp>(*val.constantVal));
358 } else if (val.isTemplateConstant()) {
359 return OpHash<ConstReadOp> {}(std::get<ConstReadOp>(*val.constantVal));
360 } else {
361 ensure(val.isBlockArgument(), "unhandled operator() case");
362
363 size_t hash = std::hash<unsigned> {}(val.getInputNum());
364 for (auto f : val.fieldRefs) {
365 hash ^= f.getHash();
366 }
367 return hash;
368 }
369}
370
371mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs) {
372 rhs.print(os);
373 return os;
374}
375
376/* ConstrainRefSet */
377
379 insert(rhs.begin(), rhs.end());
380 return *this;
381}
382
383mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs) {
384 os << "{ ";
385 std::vector<ConstrainRef> sortedRefs(rhs.begin(), rhs.end());
386 std::sort(sortedRefs.begin(), sortedRefs.end());
387 for (auto it = sortedRefs.begin(); it != sortedRefs.end();) {
388 os << *it;
389 it++;
390 if (it != sortedRefs.end()) {
391 os << ", ";
392 } else {
393 os << ' ';
394 }
395 }
396 os << '}';
397 return os;
398}
399
400} // namespace llzk
This file defines methods for symbol lookup across LLZK operations and included files.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
bool operator<(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(SymbolLookupResult< component::FieldDefOp > f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
Defines a reference to an LLZK object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool operator<(const ConstrainRef &rhs) const
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isConstant() const
bool isTemplateConstant() const
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
ConstrainRef(mlir::BlockArgument b)
bool operator==(const ConstrainRef &rhs) const
bool isBlockArgument() const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this referenc...
mlir::Type getType() const
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the pref...
unsigned getInputNum() const
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:919
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
FailureOr< ModuleOp > getRootModule(Operation *from)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
SymbolLookupResult< StructDefOp > getStructDef(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
size_t operator()(const ConstrainRef &val) const