LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRef.cpp
Go to the documentation of this file.
1//===-- ConstrainRef.cpp - ConstrainRef implementation ----------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
16#include "llzk/Util/Compare.h"
17#include "llzk/Util/Debug.h"
20
21using namespace mlir;
22
23namespace llzk {
24
25using namespace array;
26using namespace component;
27using namespace felt;
28using namespace function;
29using namespace polymorphic;
30using namespace string;
31
32/* ConstrainRefIndex */
33
34void ConstrainRefIndex::print(raw_ostream &os) const {
35 if (isField()) {
36 os << '@' << getField().getName();
37 } else if (isIndex()) {
38 os << getIndex();
39 } else {
40 auto [low, high] = getIndexRange();
41 if (ShapedType::isDynamic(high.getSExtValue())) {
42 os << "<dynamic>";
43 } else {
44 os << low << ':' << high;
45 }
46 }
47}
48
50 if (isField() && rhs.isField()) {
52 }
53 if (isIndex() && rhs.isIndex()) {
54 return safeLt(APSInt(getIndex()), APSInt(rhs.getIndex()));
55 }
56 if (isIndexRange() && rhs.isIndexRange()) {
57 auto l = getIndexRange(), r = rhs.getIndexRange();
58 auto ll = APSInt(std::get<0>(l)), lu = APSInt(std::get<1>(l));
59 auto rl = APSInt(std::get<0>(r)), ru = APSInt(std::get<1>(r));
60 return safeLt(ll, rl) || (safeEq(ll, rl) && safeLt(lu, ru));
61 }
62
63 if (isField()) {
64 return true;
65 }
66 if (isIndex() && !rhs.isField()) {
67 return true;
68 }
69
70 return false;
71}
72
74 if (c.isIndex()) {
75 // We don't hash the index directly, because the built-in LLVM hash includes
76 // the bitwidth of the APInt in the hash, which is undesirable for this application.
77 // i.e., we want an N-bit version of x to hash to the same value as an M-bit version of x,
78 // because our equality checks would consider them equal regardless of bitwidth.
79 APInt idx = c.getIndex();
80 unsigned requiredBits = idx.getSignificantBits();
81 auto hash = llvm::hash_value(idx.trunc(requiredBits));
82 return hash;
83 } else if (c.isIndexRange()) {
84 auto r = c.getIndexRange();
85 return llvm::hash_value(std::get<0>(r)) ^ llvm::hash_value(std::get<1>(r));
86 } else {
88 }
89}
90
91/* ConstrainRef */
92
101getStructDef(SymbolTableCollection &tables, ModuleOp mod, StructType ty) {
102 auto sDef = ty.getDefinition(tables, mod);
103 ensure(
104 succeeded(sDef),
105 "could not find '" + StructDefOp::getOperationName() + "' op from struct type"
106 );
107
108 return std::move(sDef.value());
109}
110
111std::vector<ConstrainRef>
112ConstrainRef::getAllConstrainRefs(SymbolTableCollection &tables, ModuleOp mod, ConstrainRef root) {
113 std::vector<ConstrainRef> res = {root};
114 for (const ConstrainRef &child : root.getAllChildren(tables, mod)) {
115 auto recursiveChildren = getAllConstrainRefs(tables, mod, child);
116 res.insert(res.end(), recursiveChildren.begin(), recursiveChildren.end());
117 }
118 return res;
119}
120
121std::vector<ConstrainRef> ConstrainRef::getAllConstrainRefs(StructDefOp structDef, FuncDefOp fnOp) {
122 std::vector<ConstrainRef> res;
123
124 ensure(
125 structDef == fnOp->getParentOfType<StructDefOp>(), "function must be within the given struct"
126 );
127
128 FailureOr<ModuleOp> modOp = getRootModule(structDef);
129 ensure(succeeded(modOp), "could not lookup module from struct " + Twine(structDef.getName()));
130
131 SymbolTableCollection tables;
132 for (auto a : fnOp.getArguments()) {
133 auto argRes = getAllConstrainRefs(tables, modOp.value(), ConstrainRef(a));
134 res.insert(res.end(), argRes.begin(), argRes.end());
135 }
136
137 // For compute functions, the "self" field is not arg0 like for constrain, but
138 // rather the struct value returned from the function.
139 if (fnOp.isStructCompute()) {
140 Value selfVal = fnOp.getSelfValueFromCompute();
141 auto createOp = dyn_cast_if_present<CreateStructOp>(selfVal.getDefiningOp());
142 ensure(createOp, "self value should originate from struct.new operation");
143 auto selfRes = getAllConstrainRefs(tables, modOp.value(), ConstrainRef(createOp));
144 res.insert(res.end(), selfRes.begin(), selfRes.end());
145 }
146
147 return res;
148}
149
150std::vector<ConstrainRef>
152 std::vector<ConstrainRef> res;
153 FuncDefOp constrainFnOp = structDef.getConstrainFuncOp();
154 ensure(
155 fieldDef->getParentOfType<StructDefOp>() == structDef, "Field " + Twine(fieldDef.getName()) +
156 " is not a field of struct " +
157 Twine(structDef.getName())
158 );
159 FailureOr<ModuleOp> modOp = getRootModule(structDef);
160 ensure(succeeded(modOp), "could not lookup module from struct " + Twine(structDef.getName()));
161
162 // Get the self argument (like `FuncDefOp::getSelfValueFromConstrain()`)
163 BlockArgument self = constrainFnOp.getArguments().front();
164 ConstrainRef fieldRef = ConstrainRef(self, {ConstrainRefIndex(fieldDef)});
165
166 SymbolTableCollection tables;
167 return getAllConstrainRefs(tables, modOp.value(), fieldRef);
168}
169
171 if (isConstantFelt()) {
172 return std::get<FeltConstantOp>(*constantVal).getType();
173 } else if (isConstantIndex()) {
174 return std::get<arith::ConstantIndexOp>(*constantVal).getType();
175 } else if (isTemplateConstant()) {
176 return std::get<ConstReadOp>(*constantVal).getType();
177 } else {
178 int array_derefs = 0;
179 int idx = fieldRefs.size() - 1;
180 while (idx >= 0 && fieldRefs[idx].isIndex()) {
181 array_derefs++;
182 idx--;
183 }
184
185 Type currTy = nullptr;
186 if (idx >= 0) {
187 currTy = fieldRefs[idx].getField().getType();
188 } else {
189 currTy = isBlockArgument() ? getBlockArgument().getType() : getCreateStructOp().getType();
190 }
191
192 while (array_derefs > 0) {
193 currTy = dyn_cast<ArrayType>(currTy).getElementType();
194 array_derefs--;
195 }
196 return currTy;
197 }
198}
199
200bool ConstrainRef::isValidPrefix(const ConstrainRef &prefix) const {
201 if (isConstant()) {
202 return false;
203 }
204
205 if (root != prefix.root || fieldRefs.size() < prefix.fieldRefs.size()) {
206 return false;
207 }
208 for (size_t i = 0; i < prefix.fieldRefs.size(); i++) {
209 if (fieldRefs[i] != prefix.fieldRefs[i]) {
210 return false;
211 }
212 }
213 return true;
214}
215
216FailureOr<std::vector<ConstrainRefIndex>> ConstrainRef::getSuffix(const ConstrainRef &prefix
217) const {
218 if (!isValidPrefix(prefix)) {
219 return failure();
220 }
221 std::vector<ConstrainRefIndex> suffix;
222 for (size_t i = prefix.fieldRefs.size(); i < fieldRefs.size(); i++) {
223 suffix.push_back(fieldRefs[i]);
224 }
225 return suffix;
226}
227
228FailureOr<ConstrainRef>
229ConstrainRef::translate(const ConstrainRef &prefix, const ConstrainRef &other) const {
230 if (isConstant()) {
231 return *this;
232 }
233 auto suffix = getSuffix(prefix);
234 if (failed(suffix)) {
235 return failure();
236 }
237
238 auto newSignalUsage = other;
239 newSignalUsage.fieldRefs.insert(newSignalUsage.fieldRefs.end(), suffix->begin(), suffix->end());
240 return newSignalUsage;
241}
242
243std::vector<ConstrainRef>
244getAllChildren(SymbolTableCollection &tables, ModuleOp mod, ArrayType arrayTy, ConstrainRef root) {
245 std::vector<ConstrainRef> res;
246 // Recurse into arrays by iterating over their elements
247 for (int64_t i = 0; i < arrayTy.getDimSize(0); i++) {
248 ConstrainRef childRef = root.createChild(ConstrainRefIndex(i));
249 res.push_back(childRef);
250 }
251
252 return res;
253}
254
255std::vector<ConstrainRef> getAllChildren(
256 SymbolTableCollection &tables, ModuleOp mod, SymbolLookupResult<StructDefOp> structDefRes,
257 ConstrainRef root
258) {
259 std::vector<ConstrainRef> res;
260 // Recurse into struct types by iterating over all their field definitions
261 for (auto f : structDefRes.get().getOps<FieldDefOp>()) {
262 // We want to store the FieldDefOp, but without the possibility of accidentally dropping the
263 // reference, so we need to re-lookup the symbol to create a SymbolLookupResult, which will
264 // manage the external module containing the field defs, if needed.
265 // TODO: It would be nice if we could manage module op references differently
266 // so we don't have to do this.
267 auto structDefCopy = structDefRes;
268 auto fieldLookup = lookupSymbolIn<FieldDefOp>(
269 tables, SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()), std::move(structDefCopy),
270 mod.getOperation()
271 );
272 ensure(succeeded(fieldLookup), "could not get SymbolLookupResult of existing FieldDefOp");
273 ConstrainRef childRef = root.createChild(ConstrainRefIndex(fieldLookup.value()));
274 // Make a reference to the current field, regardless of if it is a composite
275 // type or not.
276 res.push_back(childRef);
277 }
278 return res;
279}
280
281std::vector<ConstrainRef>
282ConstrainRef::getAllChildren(SymbolTableCollection &tables, ModuleOp mod) const {
283 auto ty = getType();
284 if (auto structTy = dyn_cast<StructType>(ty)) {
285 return llzk::getAllChildren(tables, mod, getStructDef(tables, mod, structTy), *this);
286 } else if (auto arrayType = dyn_cast<ArrayType>(ty)) {
287 return llzk::getAllChildren(tables, mod, arrayType, *this);
288 }
289 // Scalar type, no children
290 return {};
291}
292
293void ConstrainRef::print(raw_ostream &os) const {
294 if (isConstantFelt()) {
295 os << "<felt.const: " << getConstantFeltValue() << '>';
296 } else if (isConstantIndex()) {
297 os << "<index: " << getConstantIndexValue() << '>';
298 } else if (isTemplateConstant()) {
299 auto constRead = std::get<ConstReadOp>(*constantVal);
300 auto structDefOp = constRead->getParentOfType<StructDefOp>();
301 ensure(structDefOp, "struct template should have a struct parent");
302 os << '@' << structDefOp.getName() << "<[@" << constRead.getConstName() << "]>";
303 } else {
304 if (isCreateStructOp()) {
305 os << "%self";
306 } else {
307 ensure(isBlockArgument(), "unhandled print case");
308 os << "%arg" << getInputNum();
309 }
310
311 for (auto f : fieldRefs) {
312 os << "[" << f << "]";
313 }
314 }
315}
316
318 // This way two felt constants can be equal even if they are declared in separate ops.
319 if (isConstantInt() && rhs.isConstantInt()) {
320 APSInt lhsVal(getConstantValue()), rhsVal(rhs.getConstantValue());
321 return getType() == rhs.getType() && safeEq(lhsVal, rhsVal);
322 }
323 return (root == rhs.root) && (fieldRefs == rhs.fieldRefs) && (constantVal == rhs.constantVal);
324}
325
326// required for EquivalenceClasses usage
327bool ConstrainRef::operator<(const ConstrainRef &rhs) const {
328 if (isConstantFelt() && !rhs.isConstantFelt()) {
329 // Put all constants at the end
330 return false;
331 } else if (!isConstantFelt() && rhs.isConstantFelt()) {
332 return true;
333 } else if (isConstantFelt() && rhs.isConstantFelt()) {
334 APSInt lhsInt(getConstantFeltValue()), rhsInt(rhs.getConstantFeltValue());
335 return safeLt(lhsInt, rhsInt);
336 }
337
338 if (isConstantIndex() && !rhs.isConstantIndex()) {
339 // Put all constant indices next at the end
340 return false;
341 } else if (!isConstantIndex() && rhs.isConstantIndex()) {
342 return true;
343 } else if (isConstantIndex() && rhs.isConstantIndex()) {
344 APSInt lhsVal(getConstantIndexValue()), rhsVal(rhs.getConstantIndexValue());
345 return safeLt(lhsVal, rhsVal);
346 }
347
348 if (isTemplateConstant() && !rhs.isTemplateConstant()) {
349 // Put all template constants next at the end
350 return false;
351 } else if (!isTemplateConstant() && rhs.isTemplateConstant()) {
352 return true;
353 } else if (isTemplateConstant() && rhs.isTemplateConstant()) {
354 auto lhsName = std::get<ConstReadOp>(*constantVal).getConstName();
355 auto rhsName = std::get<ConstReadOp>(*rhs.constantVal).getConstName();
356 return lhsName.compare(rhsName) < 0;
357 }
358
359 // Sort out the block argument vs struct.new cases
360 if (isBlockArgument() && rhs.isCreateStructOp()) {
361 return true;
362 } else if (isCreateStructOp() && rhs.isBlockArgument()) {
363 return false;
364 } else if (isBlockArgument() && rhs.isBlockArgument()) {
365 if (getInputNum() < rhs.getInputNum()) {
366 return true;
367 } else if (getInputNum() > rhs.getInputNum()) {
368 return false;
369 }
370 } else if (isCreateStructOp() && rhs.isCreateStructOp()) {
371 CreateStructOp lhsOp = getCreateStructOp(), rhsOp = rhs.getCreateStructOp();
372 if (lhsOp < rhsOp) {
373 return true;
374 } else if (lhsOp > rhsOp) {
375 return false;
376 }
377 } else {
378 llvm_unreachable("unhandled operator< case");
379 }
380
381 for (size_t i = 0; i < fieldRefs.size() && i < rhs.fieldRefs.size(); i++) {
382 if (fieldRefs[i] < rhs.fieldRefs[i]) {
383 return true;
384 } else if (fieldRefs[i] > rhs.fieldRefs[i]) {
385 return false;
386 }
387 }
388 return fieldRefs.size() < rhs.fieldRefs.size();
389}
390
392 if (val.isConstantInt()) {
393 return llvm::hash_value(val.getConstantValue());
394 } else if (val.isTemplateConstant()) {
395 return OpHash<ConstReadOp> {}(std::get<ConstReadOp>(*val.constantVal));
396 } else {
397 ensure(val.isBlockArgument() || val.isCreateStructOp(), "unhandled ConstrainRef hash case");
398
399 size_t hash = val.isBlockArgument() ? std::hash<unsigned> {}(val.getInputNum())
401 for (auto f : val.fieldRefs) {
402 hash ^= f.getHash();
403 }
404 return hash;
405 }
406}
407
408raw_ostream &operator<<(raw_ostream &os, const ConstrainRef &rhs) {
409 rhs.print(os);
410 return os;
411}
412
413/* ConstrainRefSet */
414
416 insert(rhs.begin(), rhs.end());
417 return *this;
418}
419
420raw_ostream &operator<<(raw_ostream &os, const ConstrainRefSet &rhs) {
421 os << "{ ";
422 std::vector<ConstrainRef> sortedRefs(rhs.begin(), rhs.end());
423 std::sort(sortedRefs.begin(), sortedRefs.end());
424 for (auto it = sortedRefs.begin(); it != sortedRefs.end();) {
425 os << *it;
426 it++;
427 if (it != sortedRefs.end()) {
428 os << ", ";
429 } else {
430 os << ' ';
431 }
432 }
433 os << '}';
434 return os;
435}
436
437} // namespace llzk
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
This file defines methods symbol lookup across LLZK operations and included files.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
bool operator<(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(component::FieldDefOp f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool operator<(const ConstrainRef &rhs) const
component::CreateStructOp getCreateStructOp() const
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isConstant() const
bool isTemplateConstant() const
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
mlir::Type getType() const
std::vector< ConstrainRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this ConstrainRef, assuming this ref is not a scalar.
ConstrainRef(mlir::BlockArgument b)
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
bool operator==(const ConstrainRef &rhs) const
bool isBlockArgument() const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
mlir::BlockArgument getBlockArgument() const
ConstrainRef createChild(ConstrainRefIndex r) const
bool isCreateStructOp() const
static std::vector< ConstrainRef > getAllConstrainRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ConstrainRef root)
Produce all possible ConstrainRefs that are present starting from the given root.
bool isConstantInt() const
mlir::APInt getConstantValue() const
unsigned getInputNum() const
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:919
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e., the struct value returned from the compute function).
Definition Ops.cpp:314
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:638
FailureOr< ModuleOp > getRootModule(Operation *from)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:74
bool safeLt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:66
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
raw_ostream & operator<<(raw_ostream &os, const ConstrainRef &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
SymbolLookupResult< StructDefOp > getStructDef(SymbolTableCollection &tables, ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
std::vector< ConstrainRef > getAllChildren(SymbolTableCollection &tables, ModuleOp mod, ArrayType arrayTy, ConstrainRef root)
size_t operator()(const ConstrainRefIndex &c) const
size_t operator()(const ConstrainRef &val) const