LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SourceRefLattice.cpp
Go to the documentation of this file.
1//===-- SourceRefLattice.cpp - SourceRef lattice & utils --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Hash.h"
17
18#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
19#include <mlir/IR/Value.h>
20
21#include <llvm/Support/Debug.h>
22
23#include <numeric>
24#include <unordered_set>
25
26#define DEBUG_TYPE "llzk-constrain-ref-lattice"
27
28using namespace mlir;
29
30namespace llzk {
31
32using namespace component;
33using namespace felt;
34using namespace polymorphic;
35
36/* SourceRefLatticeValue */
37
38mlir::ChangeResult SourceRefLatticeValue::insert(const SourceRef &rhs) {
39 auto rhsVal = SourceRefLatticeValue(rhs);
40 if (isScalar()) {
41 return updateScalar(rhsVal.getScalarValue());
42 } else {
43 return foldAndUpdate(rhsVal);
44 }
45}
46
47std::pair<SourceRefLatticeValue, mlir::ChangeResult>
49 auto newVal = *this;
50 auto res = mlir::ChangeResult::NoChange;
51 if (newVal.isScalar()) {
52 res = newVal.translateScalar(translation);
53 } else {
54 for (auto &elem : newVal.getArrayValue()) {
55 auto [newElem, elemRes] = elem->translate(translation);
56 (*elem) = newElem;
57 res |= elemRes;
58 }
59 }
60 return {newVal, res};
61}
62
63std::pair<SourceRefLatticeValue, mlir::ChangeResult>
65 SourceRefIndex idx(fieldRef);
66 auto transform = [&idx](const SourceRef &r) -> SourceRef { return r.createChild(idx); };
67 return elementwiseTransform(transform);
68}
69
70std::pair<SourceRefLatticeValue, mlir::ChangeResult>
71SourceRefLatticeValue::extract(const std::vector<SourceRefIndex> &indices) const {
72 if (isArray()) {
73 ensure(indices.size() <= getNumArrayDims(), "invalid extract array operands");
74
75 // First, compute what chunk(s) to index
76 std::vector<size_t> currIdxs {0};
77 for (unsigned i = 0; i < indices.size(); i++) {
78 auto &idx = indices[i];
79 auto currDim = getArrayDim(i);
80
81 std::vector<size_t> newIdxs;
82 ensure(idx.isIndex() || idx.isIndexRange(), "wrong type of index for array");
83 if (idx.isIndex()) {
84 int64_t idxVal(idx.getIndex());
85 std::transform(
86 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
87 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
88 );
89 } else {
90 auto [low, high] = idx.getIndexRange();
91 int64_t lowInt(low), highInt(high);
92 for (int64_t idxVal = lowInt; idxVal < highInt; idxVal++) {
93 std::transform(
94 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
95 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
96 );
97 }
98 }
99
100 currIdxs = newIdxs;
101 }
102 std::vector<int64_t> newArrayDims;
103 size_t chunkSz = 1;
104 for (unsigned i = indices.size(); i < getNumArrayDims(); i++) {
105 auto dim = getArrayDim(i);
106 newArrayDims.push_back(dim);
107 chunkSz *= dim;
108 }
109 if (newArrayDims.empty()) {
110 // read case, where the return value is a scalar (single element)
111 SourceRefLatticeValue extractedVal;
112 for (auto idx : currIdxs) {
113 (void)extractedVal.update(getElemFlatIdx(idx));
114 }
115 return {extractedVal, mlir::ChangeResult::Change};
116 } else {
117 // extract case, where the return value is an array of fewer dimensions.
118 SourceRefLatticeValue extractedVal(newArrayDims);
119 for (auto chunkStart : currIdxs) {
120 for (size_t i = 0; i < chunkSz; i++) {
121 (void)extractedVal.getElemFlatIdx(i).update(getElemFlatIdx(chunkStart + i));
122 }
123 }
124 return {extractedVal, mlir::ChangeResult::Change};
125 }
126 } else {
127 auto currVal = *this;
128 auto res = mlir::ChangeResult::NoChange;
129 for (auto &idx : indices) {
130 auto transform = [&idx](const SourceRef &r) -> SourceRef { return r.createChild(idx); };
131 auto [newVal, transformRes] = currVal.elementwiseTransform(transform);
132 currVal = std::move(newVal);
133 res |= transformRes;
134 }
135 return {currVal, res};
136 }
137}
138
139mlir::ChangeResult SourceRefLatticeValue::translateScalar(const TranslationMap &translation) {
140 auto res = mlir::ChangeResult::NoChange;
141 // copy the current value
142 auto currVal = getScalarValue();
143 // reset this value
144 getValue() = ScalarTy();
145 // For each current element, see if the translation map contains a valid prefix.
146 // If so, translate the current element with all replacement prefixes indicated
147 // by the translation value.
148 for (const SourceRef &currRef : currVal) {
149 for (auto &[prefix, replacementVal] : translation) {
150 if (currRef.isValidPrefix(prefix)) {
151 for (const SourceRef &replacementPrefix : replacementVal.foldToScalar()) {
152 auto translatedRefRes = currRef.translate(prefix, replacementPrefix);
153 if (succeeded(translatedRefRes)) {
154 res |= insert(*translatedRefRes);
155 }
156 }
157 }
158 }
159 }
160 return res;
161}
162
163std::pair<SourceRefLatticeValue, mlir::ChangeResult> SourceRefLatticeValue::elementwiseTransform(
164 llvm::function_ref<SourceRef(const SourceRef &)> transform
165) const {
166 auto newVal = *this;
167 auto res = mlir::ChangeResult::NoChange;
168 if (newVal.isScalar()) {
169 ScalarTy indexed;
170 for (auto &ref : newVal.getScalarValue()) {
171 auto [_, inserted] = indexed.insert(transform(ref));
172 if (inserted) {
173 res |= mlir::ChangeResult::Change;
174 }
175 }
176 newVal.getScalarValue() = indexed;
177 } else {
178 for (auto &elem : newVal.getArrayValue()) {
179 auto [newElem, elemRes] = elem->elementwiseTransform(transform);
180 (*elem) = newElem;
181 res |= elemRes;
182 }
183 }
184 return {newVal, res};
185}
186
187mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefLatticeValue &v) {
188 v.print(os);
189 return os;
190}
191
192/* SourceRefLattice */
193
194mlir::FailureOr<SourceRef> SourceRefLattice::getSourceRef(mlir::Value val) {
195 if (auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(val)) {
196 return SourceRef(blockArg);
197 } else if (auto defOp = val.getDefiningOp()) {
198 if (auto feltConst = llvm::dyn_cast<FeltConstantOp>(defOp)) {
199 return SourceRef(feltConst);
200 } else if (auto constIdx = llvm::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
201 return SourceRef(constIdx);
202 } else if (auto readConst = llvm::dyn_cast<ConstReadOp>(defOp)) {
203 return SourceRef(readConst);
204 } else if (auto structNew = llvm::dyn_cast<CreateStructOp>(defOp)) {
205 return SourceRef(structNew);
206 }
207 }
208 return mlir::failure();
209}
210
211void SourceRefLattice::print(mlir::raw_ostream &os) const {
212 os << "SourceRefLattice { ";
213 for (auto mit = valMap.begin(); mit != valMap.end();) {
214 auto &[val, latticeVal] = *mit;
215 os << "\n (";
216 if (auto asVal = llvm::dyn_cast<Value>(val)) {
217 os << asVal;
218 } else if (auto asOp = llvm::dyn_cast<Operation *>(val)) {
219 os << *asOp;
220 } else {
221 llvm_unreachable("unhandled ValueTy print case");
222 }
223 os << ") => " << latticeVal;
224 mit++;
225 if (mit != valMap.end()) {
226 os << ',';
227 } else {
228 os << '\n';
229 }
230 }
231 os << "}\n";
232}
233
234mlir::ChangeResult SourceRefLattice::setValues(const ValueMap &rhs) {
235 auto res = mlir::ChangeResult::NoChange;
236 for (auto &[v, s] : rhs) {
237 res |= setValue(v, s);
238 }
239 return res;
240}
241
243 for (const SourceRef &ref : rhs.foldToScalar()) {
244 refMap[ref].insert(v);
245 }
246 return valMap[v].setValue(rhs);
247}
248
249mlir::ChangeResult SourceRefLattice::setValue(ValueTy v, const SourceRef &ref) {
250 refMap[ref].insert(v);
251 return valMap[v].setValue(SourceRefLatticeValue(ref));
252}
253
255 auto it = valMap.find(v);
256 if (it != valMap.end()) {
257 return it->second;
258 }
259
260 if (auto asVal = llvm::dyn_cast_if_present<Value>(v)) {
261 auto sourceRef = getSourceRef(asVal);
262 if (mlir::succeeded(sourceRef)) {
263 return SourceRefLatticeValue(sourceRef.value());
264 }
265 }
266 return SourceRefLatticeValue();
267}
268
270 ProgramPoint *pp = llvm::cast<ProgramPoint *>(this->getAnchor());
271 if (auto retOp = mlir::dyn_cast_if_present<function::ReturnOp>(pp->getPrevOp())) {
272 if (i >= retOp.getNumOperands()) {
273 llvm::report_fatal_error("return value requested is out of range");
274 }
275 return this->getOrDefault(retOp.getOperand(i));
276 }
277 return SourceRefLatticeValue();
278}
279
281 if (auto it = refMap.find(ref); it != refMap.end()) {
282 return it->second;
283 }
285}
286
287llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const SourceRefLattice &lattice) {
288 lattice.print(os);
289 return os;
290}
291
292} // namespace llzk
293
294namespace llvm {
295
296raw_ostream &operator<<(raw_ostream &os, llvm::PointerUnion<mlir::Value, mlir::Operation *> ptr) {
297 if (auto asVal = llvm::dyn_cast_if_present<Value>(ptr)) {
298 os << asVal;
299 } else if (auto asOp = llvm::dyn_cast_if_present<Operation *>(ptr)) {
300 os << *asOp;
301 } else {
302 os << "<<null PointerUnion>>";
303 }
304 return os;
305}
306} // namespace llvm
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
Defines an index into an LLZK object.
Definition SourceRef.h:36
A value at a given point of the SourceRefLattice.
virtual std::pair< SourceRefLatticeValue, mlir::ChangeResult > elementwiseTransform(llvm::function_ref< SourceRef(const SourceRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value containing the transformed elements, together with a ChangeResult.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > referenceField(SymbolLookupResult< component::FieldDefOp > fieldRef) const
Add the given fieldRef to the SourceRefs contained within this value.
mlir::ChangeResult insert(const SourceRef &rhs)
Directly insert the ref into this value.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transformed value, together with a ChangeResult.
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
A lattice for use in dense analysis.
mlir::DenseMap< ValueTy, SourceRefLatticeValue > ValueMap
mlir::DenseSet< ValueTy > ValueSet
void print(mlir::raw_ostream &os) const override
static mlir::FailureOr< SourceRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument from the function args, a constant, a constant read, or a newly created struct), return the SourceRef for val; otherwise, return failure.
ValueSet lookupValues(const SourceRef &r) const
mlir::ChangeResult setValues(const ValueMap &rhs)
SourceRefLatticeValue getOrDefault(ValueTy v) const
mlir::ChangeResult setValue(ValueTy v, const SourceRefLatticeValue &rhs)
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
SourceRefLatticeValue getReturnValue(unsigned i) const
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
void print(mlir::raw_ostream &os) const
raw_ostream & operator<<(raw_ostream &os, llvm::PointerUnion< mlir::Value, mlir::Operation * > ptr)
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::unordered_map< SourceRef, SourceRefLatticeValue, SourceRef::Hash > TranslationMap