LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRefLattice.cpp
Go to the documentation of this file.
1//===-- ConstrainRefLattice.cpp - ConstrainRef lattice & utils --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15#include "llzk/Util/Hash.h"
17
18#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
19#include <mlir/IR/Value.h>
20
21#include <llvm/Support/Debug.h>
22
23#include <numeric>
24#include <unordered_set>
25
26#define DEBUG_TYPE "llzk-constrain-ref-lattice"
27
28namespace llzk {
29
30using namespace component;
31using namespace felt;
32using namespace polymorphic;
33
34/* ConstrainRefLatticeValue */
35
36mlir::ChangeResult ConstrainRefLatticeValue::insert(const ConstrainRef &rhs) {
37 auto rhsVal = ConstrainRefLatticeValue(rhs);
38 if (isScalar()) {
39 return updateScalar(rhsVal.getScalarValue());
40 } else {
41 return foldAndUpdate(rhsVal);
42 }
43}
44
45std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
47 auto newVal = *this;
48 auto res = mlir::ChangeResult::NoChange;
49 if (newVal.isScalar()) {
50 res = newVal.translateScalar(translation);
51 } else {
52 for (auto &elem : newVal.getArrayValue()) {
53 auto [newElem, elemRes] = elem->translate(translation);
54 (*elem) = newElem;
55 res |= elemRes;
56 }
57 }
58 return {newVal, res};
59}
60
61std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
63 ConstrainRefIndex idx(fieldRef);
64 auto transform = [&idx](const ConstrainRef &r) -> ConstrainRef { return r.createChild(idx); };
65 return elementwiseTransform(transform);
66}
67
68std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
69ConstrainRefLatticeValue::extract(const std::vector<ConstrainRefIndex> &indices) const {
70 if (isArray()) {
71 ensure(indices.size() <= getNumArrayDims(), "invalid extract array operands");
72
73 // First, compute what chunk(s) to index
74 std::vector<size_t> currIdxs {0};
75 for (unsigned i = 0; i < indices.size(); i++) {
76 auto &idx = indices[i];
77 auto currDim = getArrayDim(i);
78
79 std::vector<size_t> newIdxs;
80 ensure(idx.isIndex() || idx.isIndexRange(), "wrong type of index for array");
81 if (idx.isIndex()) {
82 auto idxVal = fromAPInt(idx.getIndex());
83 std::transform(
84 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
85 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
86 );
87 } else {
88 auto [low, high] = idx.getIndexRange();
89 for (auto idxVal = fromAPInt(low); idxVal < fromAPInt(high); idxVal++) {
90 std::transform(
91 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
92 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
93 );
94 }
95 }
96
97 currIdxs = newIdxs;
98 }
99 std::vector<int64_t> newArrayDims;
100 size_t chunkSz = 1;
101 for (unsigned i = indices.size(); i < getNumArrayDims(); i++) {
102 auto dim = getArrayDim(i);
103 newArrayDims.push_back(dim);
104 chunkSz *= dim;
105 }
106 auto extractedVal = ConstrainRefLatticeValue(newArrayDims);
107 for (auto chunkStart : currIdxs) {
108 for (size_t i = 0; i < chunkSz; i++) {
109 (void)extractedVal.getElemFlatIdx(i).update(getElemFlatIdx(chunkStart + i));
110 }
111 }
112
113 return {extractedVal, mlir::ChangeResult::Change};
114 } else {
115 auto currVal = *this;
116 auto res = mlir::ChangeResult::NoChange;
117 for (auto &idx : indices) {
118 auto transform = [&idx](const ConstrainRef &r) -> ConstrainRef { return r.createChild(idx); };
119 auto [newVal, transformRes] = currVal.elementwiseTransform(transform);
120 currVal = std::move(newVal);
121 res |= transformRes;
122 }
123 return {currVal, res};
124 }
125}
126
127mlir::ChangeResult ConstrainRefLatticeValue::translateScalar(const TranslationMap &translation) {
128 auto res = mlir::ChangeResult::NoChange;
129 // copy the current value
130 auto currVal = getScalarValue();
131 // reset this value
132 getValue() = ScalarTy();
133 for (auto &[ref, val] : translation) {
134 auto it = currVal.find(ref);
135 if (it != currVal.end()) {
136 res |= update(val);
137 }
138 }
139 return res;
140}
141
142std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
144 llvm::function_ref<ConstrainRef(const ConstrainRef &)> transform
145) const {
146 auto newVal = *this;
147 auto res = mlir::ChangeResult::NoChange;
148 if (newVal.isScalar()) {
149 ScalarTy indexed;
150 for (auto &ref : newVal.getScalarValue()) {
151 auto [_, inserted] = indexed.insert(transform(ref));
152 if (inserted) {
153 res |= mlir::ChangeResult::Change;
154 }
155 }
156 newVal.getScalarValue() = indexed;
157 } else {
158 for (auto &elem : newVal.getArrayValue()) {
159 auto [newElem, elemRes] = elem->elementwiseTransform(transform);
160 (*elem) = newElem;
161 res |= elemRes;
162 }
163 }
164 return {newVal, res};
165}
166
167mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefLatticeValue &v) {
168 v.print(os);
169 return os;
170}
171
172/* ConstrainRefLattice */
173
174mlir::FailureOr<ConstrainRef> ConstrainRefLattice::getSourceRef(mlir::Value val) {
175 if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(val)) {
176 return ConstrainRef(blockArg);
177 } else if (auto defOp = val.getDefiningOp()) {
178 if (auto feltConst = mlir::dyn_cast<FeltConstantOp>(defOp)) {
179 return ConstrainRef(feltConst);
180 } else if (auto constIdx = mlir::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
181 return ConstrainRef(constIdx);
182 } else if (auto readConst = mlir::dyn_cast<ConstReadOp>(defOp)) {
183 return ConstrainRef(readConst);
184 }
185 }
186 return mlir::failure();
187}
188
189void ConstrainRefLattice::print(mlir::raw_ostream &os) const {
190 os << "ConstrainRefLattice { ";
191 for (auto mit = valMap.begin(); mit != valMap.end();) {
192 auto &[val, latticeVal] = *mit;
193 os << "\n (" << val << ") => " << latticeVal;
194 mit++;
195 if (mit != valMap.end()) {
196 os << ',';
197 } else {
198 os << '\n';
199 }
200 }
201 os << "}\n";
202}
203
204mlir::ChangeResult ConstrainRefLattice::setValues(const ValueMap &rhs) {
205 auto res = mlir::ChangeResult::NoChange;
206
207 for (auto &[v, s] : rhs) {
208 res |= setValue(v, s);
209 }
210 return res;
211}
212
214 auto it = valMap.find(v);
215 if (it == valMap.end()) {
216 auto sourceRef = getSourceRef(v);
217 if (mlir::succeeded(sourceRef)) {
218 return ConstrainRefLatticeValue(sourceRef.value());
219 }
221 }
222 return it->second;
223}
224
226 auto op = this->getPoint().get<mlir::Operation *>();
227 if (auto retOp = mlir::dyn_cast<function::ReturnOp>(op)) {
228 if (i >= retOp.getNumOperands()) {
229 llvm::report_fatal_error("return value requested is out of range");
230 }
231 return this->getOrDefault(retOp.getOperand(i));
232 }
234}
235
236llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ConstrainRefLattice &lattice) {
237 lattice.print(os);
238 return os;
239}
240
241} // namespace llzk
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
mlir::ChangeResult insert(const ConstrainRef &rhs)
Directly insert the ref into this value.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > extract(const std::vector< ConstrainRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > referenceField(SymbolLookupResult< component::FieldDefOp > fieldRef) const
Add the given fieldRef to the constrain refs contained within this value.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transformed value.
virtual std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > elementwiseTransform(llvm::function_ref< ConstrainRef(const ConstrainRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the modified elements.
A lattice for use in dense analysis.
ConstrainRefLatticeValue getOrDefault(mlir::Value v) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseMap< mlir::Value, ConstrainRefLatticeValue > ValueMap
static mlir::FailureOr< ConstrainRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument from the function args or a constant), return the base ConstrainRef for it; otherwise, return failure.
mlir::ChangeResult setValues(const ValueMap &rhs)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult setValue(mlir::Value v, const ConstrainRefLatticeValue &rhs)
Defines a reference to a llzk object within a constrain function call.
void print(mlir::raw_ostream &os) const
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
int64_t fromAPInt(llvm::APInt i)
std::unordered_map< ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash > TranslationMap