LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRefLattice.cpp
Go to the documentation of this file.
1//===-- ConstrainRefLattice.cpp - ConstrainRef lattice & utils --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Hash.h"
17
18#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
19#include <mlir/IR/Value.h>
20
21#include <llvm/Support/Debug.h>
22
23#include <numeric>
24#include <unordered_set>
25
26#define DEBUG_TYPE "llzk-constrain-ref-lattice"
27
28using namespace mlir;
29
30namespace llzk {
31
32using namespace component;
33using namespace felt;
34using namespace polymorphic;
35
36/* ConstrainRefLatticeValue */
37
38mlir::ChangeResult ConstrainRefLatticeValue::insert(const ConstrainRef &rhs) {
39 auto rhsVal = ConstrainRefLatticeValue(rhs);
40 if (isScalar()) {
41 return updateScalar(rhsVal.getScalarValue());
42 } else {
43 return foldAndUpdate(rhsVal);
44 }
45}
46
47std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
49 auto newVal = *this;
50 auto res = mlir::ChangeResult::NoChange;
51 if (newVal.isScalar()) {
52 res = newVal.translateScalar(translation);
53 } else {
54 for (auto &elem : newVal.getArrayValue()) {
55 auto [newElem, elemRes] = elem->translate(translation);
56 (*elem) = newElem;
57 res |= elemRes;
58 }
59 }
60 return {newVal, res};
61}
62
63std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
65 ConstrainRefIndex idx(fieldRef);
66 auto transform = [&idx](const ConstrainRef &r) -> ConstrainRef { return r.createChild(idx); };
67 return elementwiseTransform(transform);
68}
69
70std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
71ConstrainRefLatticeValue::extract(const std::vector<ConstrainRefIndex> &indices) const {
72 if (isArray()) {
73 ensure(indices.size() <= getNumArrayDims(), "invalid extract array operands");
74
75 // First, compute what chunk(s) to index
76 std::vector<size_t> currIdxs {0};
77 for (unsigned i = 0; i < indices.size(); i++) {
78 auto &idx = indices[i];
79 auto currDim = getArrayDim(i);
80
81 std::vector<size_t> newIdxs;
82 ensure(idx.isIndex() || idx.isIndexRange(), "wrong type of index for array");
83 if (idx.isIndex()) {
84 auto idxVal = fromAPInt(idx.getIndex());
85 std::transform(
86 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
87 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
88 );
89 } else {
90 auto [low, high] = idx.getIndexRange();
91 for (auto idxVal = fromAPInt(low); idxVal < fromAPInt(high); idxVal++) {
92 std::transform(
93 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
94 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
95 );
96 }
97 }
98
99 currIdxs = newIdxs;
100 }
101 std::vector<int64_t> newArrayDims;
102 size_t chunkSz = 1;
103 for (unsigned i = indices.size(); i < getNumArrayDims(); i++) {
104 auto dim = getArrayDim(i);
105 newArrayDims.push_back(dim);
106 chunkSz *= dim;
107 }
108 if (newArrayDims.empty()) {
109 // read case, where the return value is a scalar (single element)
110 ConstrainRefLatticeValue extractedVal;
111 for (auto idx : currIdxs) {
112 (void)extractedVal.update(getElemFlatIdx(idx));
113 }
114 return {extractedVal, mlir::ChangeResult::Change};
115 } else {
116 // extract case, where the return value is an array of fewer dimensions.
117 ConstrainRefLatticeValue extractedVal(newArrayDims);
118 for (auto chunkStart : currIdxs) {
119 for (size_t i = 0; i < chunkSz; i++) {
120 (void)extractedVal.getElemFlatIdx(i).update(getElemFlatIdx(chunkStart + i));
121 }
122 }
123 return {extractedVal, mlir::ChangeResult::Change};
124 }
125 } else {
126 auto currVal = *this;
127 auto res = mlir::ChangeResult::NoChange;
128 for (auto &idx : indices) {
129 auto transform = [&idx](const ConstrainRef &r) -> ConstrainRef { return r.createChild(idx); };
130 auto [newVal, transformRes] = currVal.elementwiseTransform(transform);
131 currVal = std::move(newVal);
132 res |= transformRes;
133 }
134 return {currVal, res};
135 }
136}
137
138mlir::ChangeResult ConstrainRefLatticeValue::translateScalar(const TranslationMap &translation) {
139 auto res = mlir::ChangeResult::NoChange;
140 // copy the current value
141 auto currVal = getScalarValue();
142 // reset this value
143 getValue() = ScalarTy();
144 // For each current element, see if the translation map contains a valid prefix.
145 // If so, translate the current element with all replacement prefixes indicated
146 // by the translation value.
147 for (const ConstrainRef &currRef : currVal) {
148 for (auto &[prefix, replacementVal] : translation) {
149 if (currRef.isValidPrefix(prefix)) {
150 for (const ConstrainRef &replacementPrefix : replacementVal.foldToScalar()) {
151 auto translatedRefRes = currRef.translate(prefix, replacementPrefix);
152 if (succeeded(translatedRefRes)) {
153 res |= insert(*translatedRefRes);
154 }
155 }
156 }
157 }
158 }
159 return res;
160}
161
162std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
164 llvm::function_ref<ConstrainRef(const ConstrainRef &)> transform
165) const {
166 auto newVal = *this;
167 auto res = mlir::ChangeResult::NoChange;
168 if (newVal.isScalar()) {
169 ScalarTy indexed;
170 for (auto &ref : newVal.getScalarValue()) {
171 auto [_, inserted] = indexed.insert(transform(ref));
172 if (inserted) {
173 res |= mlir::ChangeResult::Change;
174 }
175 }
176 newVal.getScalarValue() = indexed;
177 } else {
178 for (auto &elem : newVal.getArrayValue()) {
179 auto [newElem, elemRes] = elem->elementwiseTransform(transform);
180 (*elem) = newElem;
181 res |= elemRes;
182 }
183 }
184 return {newVal, res};
185}
186
187mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefLatticeValue &v) {
188 v.print(os);
189 return os;
190}
191
192/* ConstrainRefLattice */
193
194mlir::FailureOr<ConstrainRef> ConstrainRefLattice::getSourceRef(mlir::Value val) {
195 if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(val)) {
196 return ConstrainRef(blockArg);
197 } else if (auto defOp = val.getDefiningOp()) {
198 if (auto feltConst = mlir::dyn_cast<FeltConstantOp>(defOp)) {
199 return ConstrainRef(feltConst);
200 } else if (auto constIdx = mlir::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
201 return ConstrainRef(constIdx);
202 } else if (auto readConst = mlir::dyn_cast<ConstReadOp>(defOp)) {
203 return ConstrainRef(readConst);
204 } else if (auto structNew = mlir::dyn_cast<CreateStructOp>(defOp)) {
205 return ConstrainRef(structNew);
206 }
207 }
208 return mlir::failure();
209}
210
211void ConstrainRefLattice::print(mlir::raw_ostream &os) const {
212 os << "ConstrainRefLattice { ";
213 for (auto mit = valMap.begin(); mit != valMap.end();) {
214 auto &[val, latticeVal] = *mit;
215 os << "\n (";
216 if (val.is<Value>()) {
217 os << val.get<Value>();
218 } else if (val.is<Operation *>()) {
219 os << *val.get<Operation *>();
220 } else {
221 llvm_unreachable("unhandled ValueTy print case");
222 }
223 os << ") => " << latticeVal;
224 mit++;
225 if (mit != valMap.end()) {
226 os << ',';
227 } else {
228 os << '\n';
229 }
230 }
231 os << "}\n";
232}
233
234mlir::ChangeResult ConstrainRefLattice::setValues(const ValueMap &rhs) {
235 auto res = mlir::ChangeResult::NoChange;
236
237 for (auto &[v, s] : rhs) {
238 res |= setValue(v, s);
239 }
240 return res;
241}
242
244 for (const ConstrainRef &ref : rhs.foldToScalar()) {
245 refMap[ref].insert(v);
246 }
247 return valMap[v].setValue(rhs);
248}
249
250mlir::ChangeResult ConstrainRefLattice::setValue(ValueTy v, const ConstrainRef &ref) {
251 refMap[ref].insert(v);
252 return valMap[v].setValue(ConstrainRefLatticeValue(ref));
253}
254
// getOrDefault(v): return the lattice value recorded for `v` in valMap, if any.
256 auto it = valMap.find(v);
257 if (it != valMap.end()) {
258 return it->second;
259 }
260
// Not recorded: if `v` is an mlir::Value that getSourceRef recognizes (block
// argument or constant-like/struct-creation op), return a fresh value holding
// that source ref.
// NOTE(review): the fallback for all other cases is presumably an empty/default
// value — confirm.
261 if (v.is<Value>()) {
262 auto sourceRef = getSourceRef(v.get<Value>());
263 if (mlir::succeeded(sourceRef)) {
264 return ConstrainRefLatticeValue(sourceRef.value());
265 }
266 }
268}
269
// getReturnValue(i): lattice value of the i-th operand of the function::ReturnOp
// at this lattice's program point. Aborts via report_fatal_error if `i` is out of
// range.
// Assumes the lattice is anchored at an Operation* (get<> on the program point).
271 auto op = this->getPoint().get<mlir::Operation *>();
272 if (auto retOp = mlir::dyn_cast<function::ReturnOp>(op)) {
273 if (i >= retOp.getNumOperands()) {
274 llvm::report_fatal_error("return value requested is out of range");
275 }
276 return this->getOrDefault(retOp.getOperand(i));
277 }
// NOTE(review): the result when the point is not a ReturnOp is presumably a
// default value — confirm.
279}
280
// lookupValues(ref): the set of values recorded by setValue as (possibly) holding
// `ref`, taken from the reverse map refMap.
282 if (auto it = refMap.find(ref); it != refMap.end()) {
283 return it->second;
284 }
// NOTE(review): when `ref` was never recorded, an empty set is presumably returned — confirm.
286}
287
288llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ConstrainRefLattice &lattice) {
289 lattice.print(os);
290 return os;
291}
292
293} // namespace llzk
294
295namespace llvm {
296
297raw_ostream &operator<<(raw_ostream &os, llvm::PointerUnion<mlir::Value, mlir::Operation *> ptr) {
298 if (ptr.is<Value>()) {
299 os << ptr.get<Value>();
300 } else {
301 Operation *op = ptr.get<Operation *>();
302 if (op) {
303 os << *op;
304 } else {
305 os << "<null operation>";
306 }
307 }
308 return os;
309}
310} // namespace llvm
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
mlir::ChangeResult insert(const ConstrainRef &rhs)
Directly insert the ref into this value.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > extract(const std::vector< ConstrainRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > referenceField(SymbolLookupResult< component::FieldDefOp > fieldRef) const
Add the given fieldRef to the constrain refs contained within this value.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transformed value together with a ChangeResult.
virtual std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > elementwiseTransform(llvm::function_ref< ConstrainRef(const ConstrainRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the modified elements.
A lattice for use in dense analysis.
mlir::DenseMap< ValueTy, ConstrainRefLatticeValue > ValueMap
ValueSet lookupValues(const ConstrainRef &r) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseSet< ValueTy > ValueSet
static mlir::FailureOr< ConstrainRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument from the function args or a constant), return the ConstrainRef that represents it; otherwise, return failure.
mlir::ChangeResult setValues(const ValueMap &rhs)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult setValue(ValueTy v, const ConstrainRefLatticeValue &rhs)
ConstrainRefLatticeValue getOrDefault(ValueTy v) const
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
Defines a reference to a llzk object within a constrain function call.
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
void print(mlir::raw_ostream &os) const
raw_ostream & operator<<(raw_ostream &os, llvm::PointerUnion< mlir::Value, mlir::Operation * > ptr)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
raw_ostream & operator<<(raw_ostream &os, const ConstrainRef &rhs)
int64_t fromAPInt(llvm::APInt i)
std::unordered_map< ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash > TranslationMap