18#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
19#include <mlir/IR/Value.h>
21#include <llvm/Support/Debug.h>
24#include <unordered_set>
26#define DEBUG_TYPE "llzk-constrain-ref-lattice"
32using namespace component;
34using namespace polymorphic;
47std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
50 auto res = mlir::ChangeResult::NoChange;
51 if (newVal.isScalar()) {
52 res = newVal.translateScalar(translation);
54 for (
auto &elem : newVal.getArrayValue()) {
55 auto [newElem, elemRes] = elem->translate(translation);
63std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
70std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
76 std::vector<size_t> currIdxs {0};
77 for (
unsigned i = 0; i < indices.size(); i++) {
78 auto &idx = indices[i];
81 std::vector<size_t> newIdxs;
82 ensure(idx.isIndex() || idx.isIndexRange(),
"wrong type of index for array");
86 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
87 [&currDim, &idxVal](
size_t j) { return j * currDim + idxVal; }
90 auto [low, high] = idx.getIndexRange();
93 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
94 [&currDim, &idxVal](
size_t j) { return j * currDim + idxVal; }
101 std::vector<int64_t> newArrayDims;
105 newArrayDims.push_back(dim);
108 if (newArrayDims.empty()) {
111 for (
auto idx : currIdxs) {
114 return {extractedVal, mlir::ChangeResult::Change};
118 for (
auto chunkStart : currIdxs) {
119 for (
size_t i = 0; i < chunkSz; i++) {
123 return {extractedVal, mlir::ChangeResult::Change};
126 auto currVal = *
this;
127 auto res = mlir::ChangeResult::NoChange;
128 for (
auto &idx : indices) {
130 auto [newVal, transformRes] = currVal.elementwiseTransform(transform);
131 currVal = std::move(newVal);
134 return {currVal, res};
139 auto res = mlir::ChangeResult::NoChange;
148 for (
auto &[prefix, replacementVal] : translation) {
149 if (currRef.isValidPrefix(prefix)) {
150 for (
const ConstrainRef &replacementPrefix : replacementVal.foldToScalar()) {
151 auto translatedRefRes = currRef.translate(prefix, replacementPrefix);
152 if (succeeded(translatedRefRes)) {
153 res |=
insert(*translatedRefRes);
162std::pair<ConstrainRefLatticeValue, mlir::ChangeResult>
167 auto res = mlir::ChangeResult::NoChange;
168 if (newVal.isScalar()) {
170 for (
auto &ref : newVal.getScalarValue()) {
171 auto [_, inserted] = indexed.insert(transform(ref));
173 res |= mlir::ChangeResult::Change;
176 newVal.getScalarValue() = indexed;
178 for (
auto &elem : newVal.getArrayValue()) {
179 auto [newElem, elemRes] = elem->elementwiseTransform(transform);
184 return {newVal, res};
195 if (
auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(val)) {
197 }
else if (
auto defOp = val.getDefiningOp()) {
198 if (
auto feltConst = mlir::dyn_cast<FeltConstantOp>(defOp)) {
200 }
else if (
auto constIdx = mlir::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
202 }
else if (
auto readConst = mlir::dyn_cast<ConstReadOp>(defOp)) {
204 }
else if (
auto structNew = mlir::dyn_cast<CreateStructOp>(defOp)) {
208 return mlir::failure();
212 os <<
"ConstrainRefLattice { ";
213 for (
auto mit = valMap.begin(); mit != valMap.end();) {
214 auto &[val, latticeVal] = *mit;
216 if (val.is<Value>()) {
217 os << val.get<Value>();
218 }
else if (val.is<Operation *>()) {
219 os << *val.get<Operation *>();
221 llvm_unreachable(
"unhandled ValueTy print case");
223 os <<
") => " << latticeVal;
225 if (mit != valMap.end()) {
235 auto res = mlir::ChangeResult::NoChange;
237 for (
auto &[v, s] : rhs) {
245 refMap[ref].insert(v);
247 return valMap[v].setValue(rhs);
251 refMap[ref].insert(v);
256 auto it = valMap.find(v);
257 if (it != valMap.end()) {
263 if (mlir::succeeded(sourceRef)) {
271 auto op = this->getPoint().get<mlir::Operation *>();
272 if (
auto retOp = mlir::dyn_cast<function::ReturnOp>(op)) {
273 if (i >= retOp.getNumOperands()) {
274 llvm::report_fatal_error(
"return value requested is out of range");
282 if (
auto it = refMap.find(ref); it != refMap.end()) {
297raw_ostream &
operator<<(raw_ostream &os, llvm::PointerUnion<mlir::Value, mlir::Operation *> ptr) {
298 if (ptr.is<Value>()) {
299 os << ptr.get<Value>();
301 Operation *op = ptr.get<Operation *>();
305 os <<
"<null operation>";
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
mlir::ChangeResult insert(const ConstrainRef &rhs)
Directly insert the ref into this value.
ConstrainRefLatticeValue(ScalarTy s)
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > extract(const std::vector< ConstrainRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > referenceField(SymbolLookupResult< component::FieldDefOp > fieldRef) const
Add the given fieldRef to the constrain refs contained within this value.
ConstrainRefLatticeValue()
std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
virtual std::pair< ConstrainRefLatticeValue, mlir::ChangeResult > elementwiseTransform(llvm::function_ref< ConstrainRef(const ConstrainRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the mo...
A lattice for use in dense analysis.
mlir::DenseMap< ValueTy, ConstrainRefLatticeValue > ValueMap
ValueSet lookupValues(const ConstrainRef &r) const
ConstrainRefLatticeValue getReturnValue(unsigned i) const
mlir::DenseSet< ValueTy > ValueSet
static mlir::FailureOr< ConstrainRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument from the function args or a constant),...
mlir::ChangeResult setValues(const ValueMap &rhs)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult setValue(ValueTy v, const ConstrainRefLatticeValue &rhs)
ConstrainRefLatticeValue getOrDefault(ValueTy v) const
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
Defines a reference to a llzk object within a constrain function call.
size_t getNumArrayDims() const
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
int64_t getArrayDim(unsigned i) const
std::variant< ScalarTy, ArrayTy > & getValue()
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const ConstrainRefLatticeValue &rhs)
const ConstrainRefLatticeValue & getElemFlatIdx(unsigned i) const
const ScalarTy & getScalarValue() const
void print(mlir::raw_ostream &os) const
raw_ostream & operator<<(raw_ostream &os, llvm::PointerUnion< mlir::Value, mlir::Operation * > ptr)
void ensure(bool condition, llvm::Twine errMsg)
raw_ostream & operator<<(raw_ostream &os, const ConstrainRef &rhs)
int64_t fromAPInt(llvm::APInt i)
std::unordered_map< ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash > TranslationMap