LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SourceRef.cpp
Go to the documentation of this file.
1//===-- SourceRef.cpp - SourceRef implementation ----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Compare.h"
16#include "llzk/Util/Debug.h"
20
21using namespace mlir;
22
23namespace llzk {
24
25using namespace array;
26using namespace component;
27using namespace felt;
28using namespace function;
29using namespace polymorphic;
30using namespace string;
31
32/* SourceRefIndex */
33
34void SourceRefIndex::print(raw_ostream &os) const {
35 if (isField()) {
36 os << '@' << getField().getName();
37 } else if (isIndex()) {
38 os << getIndex();
39 } else {
40 auto [low, high] = getIndexRange();
41 if (ShapedType::isDynamic(int64_t(high))) {
42 os << "<dynamic>";
43 } else {
44 os << low << ':' << high;
45 }
46 }
47}
48
50 if (isField() && rhs.isField()) {
52 }
53 if (isIndex() && rhs.isIndex()) {
54 return getIndex() < rhs.getIndex();
55 }
56 if (isIndexRange() && rhs.isIndexRange()) {
57 auto [ll, lu] = getIndexRange();
58 auto [rl, ru] = rhs.getIndexRange();
59 return ll < rl || (ll == rl && lu < ru);
60 }
61
62 if (isField()) {
63 return true;
64 }
65 if (isIndex() && !rhs.isField()) {
66 return true;
67 }
68
69 return false;
70}
71
73 if (c.isIndex()) {
74 // We don't hash the index directly, because the built-in LLVM hash includes
75 // the bitwidth of the APInt in the hash, which is undesirable for this application.
76 // i.e., We want a N-bit version of x to hash to the same value as an M-bit version of X,
77 // because our equality checks would consider them equal regardless of bitwidth.
78 APSInt idx = toAPSInt(c.getIndex());
79 unsigned requiredBits = idx.getSignificantBits();
80 auto hash = llvm::hash_value(idx.trunc(requiredBits));
81 return hash;
82 } else if (c.isIndexRange()) {
83 auto r = c.getIndexRange();
84 return llvm::hash_value(std::get<0>(r)) ^ llvm::hash_value(std::get<1>(r));
85 } else {
87 }
88}
89
90/* SourceRef */
91
100getStructDef(SymbolTableCollection &tables, ModuleOp mod, StructType ty) {
101 auto sDef = ty.getDefinition(tables, mod);
102 ensure(
103 succeeded(sDef),
104 "could not find '" + StructDefOp::getOperationName() + "' op from struct type"
105 );
106
107 return std::move(sDef.value());
108}
109
110std::vector<SourceRef>
111SourceRef::getAllSourceRefs(SymbolTableCollection &tables, ModuleOp mod, SourceRef root) {
112 std::vector<SourceRef> res = {root};
113 for (const SourceRef &child : root.getAllChildren(tables, mod)) {
114 auto recursiveChildren = getAllSourceRefs(tables, mod, child);
115 res.insert(res.end(), recursiveChildren.begin(), recursiveChildren.end());
116 }
117 return res;
118}
119
120std::vector<SourceRef> SourceRef::getAllSourceRefs(StructDefOp structDef, FuncDefOp fnOp) {
121 std::vector<SourceRef> res;
122
123 ensure(
124 structDef == fnOp->getParentOfType<StructDefOp>(), "function must be within the given struct"
125 );
126
127 FailureOr<ModuleOp> modOp = getRootModule(structDef);
128 ensure(succeeded(modOp), "could not lookup module from struct " + Twine(structDef.getName()));
129
130 SymbolTableCollection tables;
131 for (auto a : fnOp.getArguments()) {
132 auto argRes = getAllSourceRefs(tables, modOp.value(), SourceRef(a));
133 res.insert(res.end(), argRes.begin(), argRes.end());
134 }
135
136 // For compute functions, the "self" field is not arg0 like for constrain, but
137 // rather the struct value returned from the function.
138 if (fnOp.isStructCompute()) {
139 Value selfVal = fnOp.getSelfValueFromCompute();
140 auto createOp = dyn_cast_if_present<CreateStructOp>(selfVal.getDefiningOp());
141 ensure(createOp, "self value should originate from struct.new operation");
142 auto selfRes = getAllSourceRefs(tables, modOp.value(), SourceRef(createOp));
143 res.insert(res.end(), selfRes.begin(), selfRes.end());
144 }
145
146 return res;
147}
148
149std::vector<SourceRef> SourceRef::getAllSourceRefs(StructDefOp structDef, FieldDefOp fieldDef) {
150 std::vector<SourceRef> res;
151 FuncDefOp constrainFnOp = structDef.getConstrainFuncOp();
152 ensure(
153 fieldDef->getParentOfType<StructDefOp>() == structDef, "Field " + Twine(fieldDef.getName()) +
154 " is not a field of struct " +
155 Twine(structDef.getName())
156 );
157 FailureOr<ModuleOp> modOp = getRootModule(structDef);
158 ensure(succeeded(modOp), "could not lookup module from struct " + Twine(structDef.getName()));
159
160 // Get the self argument (like `FuncDefOp::getSelfValueFromConstrain()`)
161 BlockArgument self = constrainFnOp.getArguments().front();
162 SourceRef fieldRef = SourceRef(self, {SourceRefIndex(fieldDef)});
163
164 SymbolTableCollection tables;
165 return getAllSourceRefs(tables, modOp.value(), fieldRef);
166}
167
168Type SourceRef::getType() const {
169 if (isConstantFelt()) {
170 return std::get<FeltConstantOp>(*constantVal).getType();
171 } else if (isConstantIndex()) {
172 return std::get<arith::ConstantIndexOp>(*constantVal).getType();
173 } else if (isTemplateConstant()) {
174 return std::get<ConstReadOp>(*constantVal).getType();
175 } else {
176 int array_derefs = 0;
177 int idx = fieldRefs.size() - 1;
178 while (idx >= 0 && fieldRefs[idx].isIndex()) {
179 array_derefs++;
180 idx--;
181 }
182
183 Type currTy = nullptr;
184 if (idx >= 0) {
185 currTy = fieldRefs[idx].getField().getType();
186 } else {
187 currTy = isBlockArgument() ? getBlockArgument().getType() : getCreateStructOp().getType();
188 }
189
190 while (array_derefs > 0) {
191 currTy = dyn_cast<ArrayType>(currTy).getElementType();
192 array_derefs--;
193 }
194 return currTy;
195 }
196}
197
198bool SourceRef::isValidPrefix(const SourceRef &prefix) const {
199 if (isConstant()) {
200 return false;
201 }
202
203 if (root != prefix.root || fieldRefs.size() < prefix.fieldRefs.size()) {
204 return false;
205 }
206 for (size_t i = 0; i < prefix.fieldRefs.size(); i++) {
207 if (fieldRefs[i] != prefix.fieldRefs[i]) {
208 return false;
209 }
210 }
211 return true;
212}
213
214FailureOr<std::vector<SourceRefIndex>> SourceRef::getSuffix(const SourceRef &prefix) const {
215 if (!isValidPrefix(prefix)) {
216 return failure();
217 }
218 std::vector<SourceRefIndex> suffix;
219 for (size_t i = prefix.fieldRefs.size(); i < fieldRefs.size(); i++) {
220 suffix.push_back(fieldRefs[i]);
221 }
222 return suffix;
223}
224
225FailureOr<SourceRef> SourceRef::translate(const SourceRef &prefix, const SourceRef &other) const {
226 if (isConstant()) {
227 return *this;
228 }
229 auto suffix = getSuffix(prefix);
230 if (failed(suffix)) {
231 return failure();
232 }
233
234 auto newSignalUsage = other;
235 newSignalUsage.fieldRefs.insert(newSignalUsage.fieldRefs.end(), suffix->begin(), suffix->end());
236 return newSignalUsage;
237}
238
239std::vector<SourceRef>
240getAllChildren(SymbolTableCollection &tables, ModuleOp /*mod*/, ArrayType arrayTy, SourceRef root) {
241 std::vector<SourceRef> res;
242 // Recurse into arrays by iterating over their elements
243 for (int64_t i = 0; i < arrayTy.getDimSize(0); i++) {
244 SourceRef childRef = root.createChild(SourceRefIndex(i));
245 res.push_back(childRef);
246 }
247
248 return res;
249}
250
251std::vector<SourceRef> getAllChildren(
252 SymbolTableCollection &tables, ModuleOp mod, SymbolLookupResult<StructDefOp> structDefRes,
253 SourceRef root
254) {
255 std::vector<SourceRef> res;
256 // Recurse into struct types by iterating over all their field definitions
257 for (auto f : structDefRes.get().getOps<FieldDefOp>()) {
258 // We want to store the FieldDefOp, but without the possibility of accidentally dropping the
259 // reference, so we need to re-lookup the symbol to create a SymbolLookupResult, which will
260 // manage the external module containing the field defs, if needed.
261 // TODO: It would be nice if we could manage module op references differently
262 // so we don't have to do this.
263 auto structDefCopy = structDefRes;
264 auto fieldLookup = lookupSymbolIn<FieldDefOp>(
265 tables, SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()), std::move(structDefCopy),
266 mod.getOperation()
267 );
268 ensure(succeeded(fieldLookup), "could not get SymbolLookupResult of existing FieldDefOp");
269 SourceRef childRef = root.createChild(SourceRefIndex(fieldLookup.value()));
270 // Make a reference to the current field, regardless of if it is a composite
271 // type or not.
272 res.push_back(childRef);
273 }
274 return res;
275}
276
277std::vector<SourceRef>
278SourceRef::getAllChildren(SymbolTableCollection &tables, ModuleOp mod) const {
279 auto ty = getType();
280 if (auto structTy = dyn_cast<StructType>(ty)) {
281 return llzk::getAllChildren(tables, mod, getStructDef(tables, mod, structTy), *this);
282 } else if (auto arrayType = dyn_cast<ArrayType>(ty)) {
283 return llzk::getAllChildren(tables, mod, arrayType, *this);
284 }
285 // Scalar type, no children
286 return {};
287}
288
289void SourceRef::print(raw_ostream &os) const {
290 if (isConstantFelt()) {
291 os << "<felt.const: " << getConstantFeltValue() << '>';
292 } else if (isConstantIndex()) {
293 os << "<index: " << getConstantIndexValue() << '>';
294 } else if (isTemplateConstant()) {
295 auto constRead = std::get<ConstReadOp>(*constantVal);
296 auto structDefOp = constRead->getParentOfType<StructDefOp>();
297 ensure(structDefOp, "struct template should have a struct parent");
298 os << '@' << structDefOp.getName() << "<[@" << constRead.getConstName() << "]>";
299 } else {
300 if (isCreateStructOp()) {
301 os << "%self";
302 } else {
303 ensure(isBlockArgument(), "unhandled print case");
304 os << "%arg" << getInputNum();
305 }
306
307 for (auto f : fieldRefs) {
308 os << "[" << f << "]";
309 }
310 }
311}
312
313bool SourceRef::operator==(const SourceRef &rhs) const {
314 // This way two felt constants can be equal even if the declared in separate ops.
315 if (isConstantInt() && rhs.isConstantInt()) {
316 DynamicAPInt lhsVal = getConstantValue(), rhsVal = rhs.getConstantValue();
317 return getType() == rhs.getType() && lhsVal == rhsVal;
318 }
319 return (root == rhs.root) && (fieldRefs == rhs.fieldRefs) && (constantVal == rhs.constantVal);
320}
321
322// required for EquivalenceClasses usage
323bool SourceRef::operator<(const SourceRef &rhs) const {
324 if (isConstantFelt() && !rhs.isConstantFelt()) {
325 // Put all constants at the end
326 return false;
327 } else if (!isConstantFelt() && rhs.isConstantFelt()) {
328 return true;
329 } else if (isConstantFelt() && rhs.isConstantFelt()) {
330 DynamicAPInt lhsInt = getConstantFeltValue(), rhsInt = rhs.getConstantFeltValue();
331 return lhsInt < rhsInt;
332 }
333
334 if (isConstantIndex() && !rhs.isConstantIndex()) {
335 // Put all constant indices next at the end
336 return false;
337 } else if (!isConstantIndex() && rhs.isConstantIndex()) {
338 return true;
339 } else if (isConstantIndex() && rhs.isConstantIndex()) {
340 DynamicAPInt lhsVal = getConstantIndexValue(), rhsVal = rhs.getConstantIndexValue();
341 return lhsVal < rhsVal;
342 }
343
344 if (isTemplateConstant() && !rhs.isTemplateConstant()) {
345 // Put all template constants next at the end
346 return false;
347 } else if (!isTemplateConstant() && rhs.isTemplateConstant()) {
348 return true;
349 } else if (isTemplateConstant() && rhs.isTemplateConstant()) {
350 StringRef lhsName = std::get<ConstReadOp>(*constantVal).getConstName();
351 StringRef rhsName = std::get<ConstReadOp>(*rhs.constantVal).getConstName();
352 return lhsName.compare(rhsName) < 0;
353 }
354
355 // Sort out the block argument vs struct.new cases
356 if (isBlockArgument() && rhs.isCreateStructOp()) {
357 return true;
358 } else if (isCreateStructOp() && rhs.isBlockArgument()) {
359 return false;
360 } else if (isBlockArgument() && rhs.isBlockArgument()) {
361 if (getInputNum() < rhs.getInputNum()) {
362 return true;
363 } else if (getInputNum() > rhs.getInputNum()) {
364 return false;
365 }
366 } else if (isCreateStructOp() && rhs.isCreateStructOp()) {
367 CreateStructOp lhsOp = getCreateStructOp(), rhsOp = rhs.getCreateStructOp();
368 if (lhsOp < rhsOp) {
369 return true;
370 } else if (lhsOp > rhsOp) {
371 return false;
372 }
373 } else {
374 llvm_unreachable("unhandled operator< case");
375 }
376
377 for (size_t i = 0; i < fieldRefs.size() && i < rhs.fieldRefs.size(); i++) {
378 if (fieldRefs[i] < rhs.fieldRefs[i]) {
379 return true;
380 } else if (fieldRefs[i] > rhs.fieldRefs[i]) {
381 return false;
382 }
383 }
384 return fieldRefs.size() < rhs.fieldRefs.size();
385}
386
387size_t SourceRef::Hash::operator()(const SourceRef &val) const {
388 if (val.isConstantInt()) {
389 return llvm::hash_value(val.getConstantValue());
390 } else if (val.isTemplateConstant()) {
391 return OpHash<ConstReadOp> {}(std::get<ConstReadOp>(*val.constantVal));
392 } else {
393 ensure(val.isBlockArgument() || val.isCreateStructOp(), "unhandled SourceRef hash case");
394
395 size_t hash = val.isBlockArgument() ? std::hash<unsigned> {}(val.getInputNum())
397 for (auto f : val.fieldRefs) {
398 hash ^= f.getHash();
399 }
400 return hash;
401 }
402}
403
404raw_ostream &operator<<(raw_ostream &os, const SourceRef &rhs) {
405 rhs.print(os);
406 return os;
407}
408
409/* SourceRefSet */
410
412 insert(rhs.begin(), rhs.end());
413 return *this;
414}
415
416raw_ostream &operator<<(raw_ostream &os, const SourceRefSet &rhs) {
417 os << "{ ";
418 std::vector<SourceRef> sortedRefs(rhs.begin(), rhs.end());
419 std::sort(sortedRefs.begin(), sortedRefs.end());
420 for (auto it = sortedRefs.begin(); it != sortedRefs.end();) {
421 os << *it;
422 it++;
423 if (it != sortedRefs.end()) {
424 os << ", ";
425 } else {
426 os << ' ';
427 }
428 }
429 os << '}';
430 return os;
431}
432
433} // namespace llzk
This file implements helper methods for constructing DynamicAPInts.
This file defines methods for symbol lookup across LLZK operations and included files.
bool isIndexRange() const
Definition SourceRef.h:67
bool operator<(const SourceRefIndex &rhs) const
Definition SourceRef.cpp:49
component::FieldDefOp getField() const
Definition SourceRef.h:53
bool isIndex() const
Definition SourceRef.h:61
SourceRefIndex(component::FieldDefOp f)
Definition SourceRef.h:40
llvm::DynamicAPInt getIndex() const
Definition SourceRef.h:62
void print(mlir::raw_ostream &os) const
Definition SourceRef.cpp:34
IndexRange getIndexRange() const
Definition SourceRef.h:68
bool isField() const
Definition SourceRef.h:49
SourceRefSet & join(const SourceRefSet &rhs)
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
bool isBlockArgument() const
Definition SourceRef.h:183
llvm::DynamicAPInt getConstantIndexValue() const
Definition SourceRef.h:207
std::vector< SourceRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this SourceRef, assuming this ref is not a scalar.
void print(mlir::raw_ostream &os) const
mlir::FailureOr< std::vector< SourceRefIndex > > getSuffix(const SourceRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the pref...
bool operator==(const SourceRef &rhs) const
bool isConstantFelt() const
Definition SourceRef.h:160
component::CreateStructOp getCreateStructOp() const
Definition SourceRef.h:195
bool isValidPrefix(const SourceRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
bool isConstantIndex() const
Definition SourceRef.h:163
SourceRef createChild(SourceRefIndex r) const
Definition SourceRef.h:252
mlir::BlockArgument getBlockArgument() const
Definition SourceRef.h:186
llvm::DynamicAPInt getConstantValue() const
Definition SourceRef.h:213
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
bool operator<(const SourceRef &rhs) const
SourceRef(mlir::BlockArgument b)
Definition SourceRef.h:143
mlir::FailureOr< SourceRef > translate(const SourceRef &prefix, const SourceRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this referenc...
bool isTemplateConstant() const
Definition SourceRef.h:167
bool isConstant() const
Definition SourceRef.h:171
llvm::DynamicAPInt getConstantFeltValue() const
Definition SourceRef.h:200
unsigned getInputNum() const
Definition SourceRef.h:190
bool isConstantInt() const
Definition SourceRef.h:172
bool isCreateStructOp() const
Definition SourceRef.h:192
mlir::Type getType() const
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1152
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:427
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:47
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:333
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:772
void ensure(bool condition, const llvm::Twine &errMsg)
FailureOr< ModuleOp > getRootModule(Operation *from)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::vector< SourceRef > getAllChildren(SymbolTableCollection &tables, ModuleOp, ArrayType arrayTy, SourceRef root)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
SymbolLookupResult< StructDefOp > getStructDef(SymbolTableCollection &tables, ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
size_t operator()(const SourceRefIndex &c) const
Definition SourceRef.cpp:72
size_t operator()(const SourceRef &val) const