18#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
19#include <mlir/IR/Value.h>
21#include <llvm/Support/Debug.h>
24#include <unordered_set>
26#define DEBUG_TYPE "llzk-constrain-ref-lattice"
32using namespace component;
34using namespace polymorphic;
47std::pair<SourceRefLatticeValue, mlir::ChangeResult>
50 auto res = mlir::ChangeResult::NoChange;
51 if (newVal.isScalar()) {
52 res = newVal.translateScalar(translation);
54 for (
auto &elem : newVal.getArrayValue()) {
55 auto [newElem, elemRes] = elem->translate(translation);
63std::pair<SourceRefLatticeValue, mlir::ChangeResult>
66 auto transform = [&idx](
const SourceRef &r) ->
SourceRef {
return r.createChild(idx); };
70std::pair<SourceRefLatticeValue, mlir::ChangeResult>
76 std::vector<size_t> currIdxs {0};
77 for (
unsigned i = 0; i < indices.size(); i++) {
78 auto &idx = indices[i];
81 std::vector<size_t> newIdxs;
82 ensure(idx.isIndex() || idx.isIndexRange(),
"wrong type of index for array");
84 int64_t idxVal(idx.getIndex());
86 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
87 [&currDim, &idxVal](
size_t j) { return j * currDim + idxVal; }
90 auto [low, high] = idx.getIndexRange();
91 int64_t lowInt(low), highInt(high);
92 for (int64_t idxVal = lowInt; idxVal < highInt; idxVal++) {
94 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
95 [&currDim, &idxVal](
size_t j) { return j * currDim + idxVal; }
102 std::vector<int64_t> newArrayDims;
106 newArrayDims.push_back(dim);
109 if (newArrayDims.empty()) {
112 for (
auto idx : currIdxs) {
115 return {extractedVal, mlir::ChangeResult::Change};
119 for (
auto chunkStart : currIdxs) {
120 for (
size_t i = 0; i < chunkSz; i++) {
124 return {extractedVal, mlir::ChangeResult::Change};
127 auto currVal = *
this;
128 auto res = mlir::ChangeResult::NoChange;
129 for (
auto &idx : indices) {
130 auto transform = [&idx](
const SourceRef &r) ->
SourceRef {
return r.createChild(idx); };
131 auto [newVal, transformRes] = currVal.elementwiseTransform(transform);
132 currVal = std::move(newVal);
135 return {currVal, res};
140 auto res = mlir::ChangeResult::NoChange;
148 for (
const SourceRef &currRef : currVal) {
149 for (
auto &[prefix, replacementVal] : translation) {
150 if (currRef.isValidPrefix(prefix)) {
151 for (
const SourceRef &replacementPrefix : replacementVal.foldToScalar()) {
152 auto translatedRefRes = currRef.translate(prefix, replacementPrefix);
153 if (succeeded(translatedRefRes)) {
154 res |=
insert(*translatedRefRes);
167 auto res = mlir::ChangeResult::NoChange;
168 if (newVal.isScalar()) {
170 for (
auto &ref : newVal.getScalarValue()) {
171 auto [_, inserted] = indexed.insert(transform(ref));
173 res |= mlir::ChangeResult::Change;
176 newVal.getScalarValue() = indexed;
178 for (
auto &elem : newVal.getArrayValue()) {
179 auto [newElem, elemRes] = elem->elementwiseTransform(transform);
184 return {newVal, res};
195 if (
auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(val)) {
197 }
else if (
auto defOp = val.getDefiningOp()) {
198 if (
auto feltConst = llvm::dyn_cast<FeltConstantOp>(defOp)) {
200 }
else if (
auto constIdx = llvm::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
202 }
else if (
auto readConst = llvm::dyn_cast<ConstReadOp>(defOp)) {
204 }
else if (
auto structNew = llvm::dyn_cast<CreateStructOp>(defOp)) {
208 return mlir::failure();
212 os <<
"SourceRefLattice { ";
213 for (
auto mit = valMap.begin(); mit != valMap.end();) {
214 auto &[val, latticeVal] = *mit;
216 if (
auto asVal = llvm::dyn_cast<Value>(val)) {
218 }
else if (
auto asOp = llvm::dyn_cast<Operation *>(val)) {
221 llvm_unreachable(
"unhandled ValueTy print case");
223 os <<
") => " << latticeVal;
225 if (mit != valMap.end()) {
235 auto res = mlir::ChangeResult::NoChange;
236 for (
auto &[v, s] : rhs) {
244 refMap[ref].insert(v);
246 return valMap[v].setValue(rhs);
250 refMap[ref].insert(v);
255 auto it = valMap.find(v);
256 if (it != valMap.end()) {
260 if (
auto asVal = llvm::dyn_cast_if_present<Value>(v)) {
262 if (mlir::succeeded(sourceRef)) {
270 ProgramPoint *pp = llvm::cast<ProgramPoint *>(this->getAnchor());
271 if (
auto retOp = mlir::dyn_cast_if_present<function::ReturnOp>(pp->getPrevOp())) {
272 if (i >= retOp.getNumOperands()) {
273 llvm::report_fatal_error(
"return value requested is out of range");
281 if (
auto it = refMap.find(ref); it != refMap.end()) {
296raw_ostream &
operator<<(raw_ostream &os, llvm::PointerUnion<mlir::Value, mlir::Operation *> ptr) {
297 if (
auto asVal = llvm::dyn_cast_if_present<Value>(ptr)) {
299 }
else if (
auto asOp = llvm::dyn_cast_if_present<Operation *>(ptr)) {
302 os <<
"<<null PointerUnion>>";
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
virtual std::pair< SourceRefLatticeValue, mlir::ChangeResult > elementwiseTransform(llvm::function_ref< SourceRef(const SourceRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the mo...
std::pair< SourceRefLatticeValue, mlir::ChangeResult > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > referenceField(SymbolLookupResult< component::FieldDefOp > fieldRef) const
Add the given fieldRef to the SourceRefs contained within this value.
mlir::ChangeResult insert(const SourceRef &rhs)
Directly insert the ref into this value.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
SourceRefLatticeValue(ScalarTy s)
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
A lattice for use in dense analysis.
mlir::DenseMap< ValueTy, SourceRefLatticeValue > ValueMap
mlir::DenseSet< ValueTy > ValueSet
void print(mlir::raw_ostream &os) const override
static mlir::FailureOr< SourceRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument from the function args or a constant),...
ValueSet lookupValues(const SourceRef &r) const
mlir::ChangeResult setValues(const ValueMap &rhs)
SourceRefLatticeValue getOrDefault(ValueTy v) const
mlir::ChangeResult setValue(ValueTy v, const SourceRefLatticeValue &rhs)
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
SourceRefLatticeValue getReturnValue(unsigned i) const
A reference to a "source", which is the base value from which other SSA values are derived.
size_t getNumArrayDims() const
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
int64_t getArrayDim(unsigned i) const
std::variant< ScalarTy, ArrayTy > & getValue()
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const SourceRefLatticeValue &rhs)
const SourceRefLatticeValue & getElemFlatIdx(unsigned i) const
const ScalarTy & getScalarValue() const
void print(mlir::raw_ostream &os) const
raw_ostream & operator<<(raw_ostream &os, llvm::PointerUnion< mlir::Value, mlir::Operation * > ptr)
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::unordered_map< SourceRef, SourceRefLatticeValue, SourceRef::Hash > TranslationMap