LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRefLattice.cpp
Go to the documentation of this file.
1//===-- ConstrainRefLattice.cpp - ConstrainRef lattice & utils --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Hash.h"
17
18#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
19#include <mlir/IR/Value.h>
20
21#include <llvm/Support/Debug.h>
22
23#include <numeric>
24#include <unordered_set>
25
26#define DEBUG_TYPE "llzk-constrain-ref-lattice"
27
28namespace llzk {
29
30using namespace component;
31using namespace felt;
32using namespace polymorphic;
33
34/* ConstrainRefLatticeValue */
35
36mlir::ChangeResult ConstrainRefLatticeValue::insert(const ConstrainRef &rhs) {
37 auto rhsVal = ConstrainRefLatticeValue(rhs);
38 if (isScalar()) {
39 return updateScalar(rhsVal.getScalarValue());
40 } else {
41 return foldAndUpdate(rhsVal);
42 }
43}
44
45std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
47 auto newVal = *this;
48 auto res = mlir::ChangeResult::NoChange;
49 if (newVal.isScalar()) {
50 res = newVal.translateScalar(translation);
51 } else {
52 for (auto &elem : newVal.getArrayValue()) {
53 auto [newElem, elemRes] = elem->translate(translation);
54 (*elem) = newElem;
55 res |= elemRes;
56 }
57 }
58 return {newVal, res};
59}
60
61std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
63 ConstrainRefIndex idx(fieldRef);
64 auto transform = [&idx](const ConstrainRef &r) -> ConstrainRef { return r.createChild(idx); };
65 return elementwiseTransform(transform);
66}
67
68std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
69ConstrainRefLatticeValue::extract(const std::vector<ConstrainRefIndex> &indices) const {
70 if (isArray()) {
71 ensure(indices.size() <= getNumArrayDims(), "invalid extract array operands");
72
73 // First, compute what chunk(s) to index
74 std::vector<size_t> currIdxs {0};
75 for (unsigned i = 0; i < indices.size(); i++) {
76 auto &idx = indices[i];
77 auto currDim = getArrayDim(i);
78
79 std::vector<size_t> newIdxs;
80 ensure(idx.isIndex() || idx.isIndexRange(), "wrong type of index for array");
81 if (idx.isIndex()) {
82 auto idxVal = fromAPInt(idx.getIndex());
83 std::transform(
84 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
85 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
86 );
87 } else {
88 auto [low, high] = idx.getIndexRange();
89 for (auto idxVal = fromAPInt(low); idxVal < fromAPInt(high); idxVal++) {
90 std::transform(
91 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
92 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
93 );
94 }
95 }
96
97 currIdxs = newIdxs;
98 }
99 std::vector<int64_t> newArrayDims;
100 size_t chunkSz = 1;
101 for (unsigned i = indices.size(); i < getNumArrayDims(); i++) {
102 auto dim = getArrayDim(i);
103 newArrayDims.push_back(dim);
104 chunkSz *= dim;
105 }
106 if (newArrayDims.empty()) {
107 // read case, where the return value is a scalar (single element)
108 ConstrainRefLatticeValue extractedVal;
109 for (auto idx : currIdxs) {
110 (void)extractedVal.update(getElemFlatIdx(idx));
111 }
112 return {extractedVal, mlir::ChangeResult::Change};
113 } else {
114 // extract case, where the return value is an array of fewer dimensions.
115 ConstrainRefLatticeValue extractedVal(newArrayDims);
116 for (auto chunkStart : currIdxs) {
117 for (size_t i = 0; i < chunkSz; i++) {
118 (void)extractedVal.getElemFlatIdx(i).update(getElemFlatIdx(chunkStart + i));
119 }
120 }
121 return {extractedVal, mlir::ChangeResult::Change};
122 }
123 } else {
124 auto currVal = *this;
125 auto res = mlir::ChangeResult::NoChange;
126 for (auto &idx : indices) {
127 auto transform = [&idx](const ConstrainRef &r) -> ConstrainRef { return r.createChild(idx); };
128 auto [newVal, transformRes] = currVal.elementwiseTransform(transform);
129 currVal = std::move(newVal);
130 res |= transformRes;
131 }
132 return {currVal, res};
133 }
134}
135
136mlir::ChangeResult ConstrainRefLatticeValue::translateScalar(const TranslationMap &translation) {
137 auto res = mlir::ChangeResult::NoChange;
138 // copy the current value
139 auto currVal = getScalarValue();
140 // reset this value
141 getValue() = ScalarTy();
142 for (auto &[ref, val] : translation) {
143 auto it = currVal.find(ref);
144 if (it != currVal.end()) {
145 res |= update(val);
146 }
147 }
148 return res;
149}
150
151std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
153 llvm::function_ref<ConstrainRef(const ConstrainRef &)> transform
154) const {
155 auto newVal = *this;
156 auto res = mlir::ChangeResult::NoChange;
157 if (newVal.isScalar()) {
158 ScalarTy indexed;
159 for (auto &ref : newVal.getScalarValue()) {
160 auto [_, inserted] = indexed.insert(transform(ref));
161 if (inserted) {
162 res |= mlir::ChangeResult::Change;
163 }
164 }
165 newVal.getScalarValue() = indexed;
166 } else {
167 for (auto &elem : newVal.getArrayValue()) {
168 auto [newElem, elemRes] = elem->elementwiseTransform(transform);
169 (*elem) = newElem;
170 res |= elemRes;
171 }
172 }
173 return {newVal, res};
174}
175
176mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefLatticeValue &v) {
177 v.print(os);
178 return os;
179}
180
181/* ConstrainRefLattice */
182
183mlir::FailureOr<ConstrainRef> ConstrainRefLattice::getSourceRef(mlir::Value val) {
184 if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(val)) {
185 return ConstrainRef(blockArg);
186 } else if (auto defOp = val.getDefiningOp()) {
187 if (auto feltConst = mlir::dyn_cast<FeltConstantOp>(defOp)) {
188 return ConstrainRef(feltConst);
189 } else if (auto constIdx = mlir::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
190 return ConstrainRef(constIdx);
191 } else if (auto readConst = mlir::dyn_cast<ConstReadOp>(defOp)) {
192 return ConstrainRef(readConst);
193 }
194 }
195 return mlir::failure();
196}
197
198void ConstrainRefLattice::print(mlir::raw_ostream &os) const {
199 os << "ConstrainRefLattice { ";
200 for (auto mit = valMap.begin(); mit != valMap.end();) {
201 auto &[val, latticeVal] = *mit;
202 os << "\n (" << val << ") => " << latticeVal;
203 mit++;
204 if (mit != valMap.end()) {
205 os << ',';
206 } else {
207 os << '\n';
208 }
209 }
210 os << "}\n";
211}
212
213mlir::ChangeResult ConstrainRefLattice::setValues(const ValueMap &rhs) {
214 auto res = mlir::ChangeResult::NoChange;
215
216 for (auto &[v, s] : rhs) {
217 res |= setValue(v, s);
218 }
219 return res;
220}
221
223 auto it = valMap.find(v);
224 if (it != valMap.end()) {
225 return it->second;
226 }
227
228 auto sourceRef = getSourceRef(v);
229 if (mlir::succeeded(sourceRef)) {
230 return ConstrainRefLatticeValue(sourceRef.value());
231 }
233}
234
236 auto op = this->getPoint().get<mlir::Operation *>();
237 if (auto retOp = mlir::dyn_cast<function::ReturnOp>(op)) {
238 if (i >= retOp.getNumOperands()) {
239 llvm::report_fatal_error("return value requested is out of range");
240 }
241 return this->getOrDefault(retOp.getOperand(i));
242 }
244}
245
246llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ConstrainRefLattice &lattice) {
247 lattice.print(os);
248 return os;
249}
250
251} // namespace llzk
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
mlir::ChangeResult insert(const ConstrainRef &rhs)
Directly insert the ref into this value.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > extract(const std::vector< ConstrainRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > referenceField(SymbolLookupResult< component::FieldDefOp > fieldRef) const
Add the given fieldRef to the constrain refs contained within this value.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the translated value, along with whether the translation reported a change.
virtual std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > elementwiseTransform(llvm::function_ref< ConstrainRef(const ConstrainRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the modified elements, along with whether a change was reported.
A lattice for use in dense analysis.
ConstrainRefLatticeValue getOrDefault(mlir::Value v) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseMap< mlir::Value, ConstrainRefLatticeValue > ValueMap
static mlir::FailureOr< ConstrainRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument from the function args or a constant), return a ConstrainRef for it; otherwise, return failure.
mlir::ChangeResult setValues(const ValueMap &rhs)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult setValue(mlir::Value v, const ConstrainRefLatticeValue &rhs)
Defines a reference to an LLZK object within a constrain function call.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
void print(mlir::raw_ostream &os) const
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
int64_t fromAPInt(llvm::APInt i)
std::unordered_map< ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash > TranslationMap