LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
AbstractLatticeValue.h
Go to the documentation of this file.
1//===-- AbstractLatticeValue.h ----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
12#include "llzk/Util/Debug.h"
14
15#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
16#include <mlir/Support/LLVM.h>
17
18#include <llvm/Support/Debug.h>
19
20#include <concepts>
21#include <type_traits>
22#include <variant>
23
24#define DEBUG_TYPE "llzk-abstract-lattice-value"
25
26namespace llzk::dataflow {
27
28template <typename Val>
30 // Require default constructable
31 std::default_initializable<Val> && requires(Val lhs, Val rhs, mlir::raw_ostream &os) {
32 // Require a form of print function
33 { os << lhs } -> std::same_as<mlir::raw_ostream &>;
34 // Require comparability
35 { lhs == rhs } -> std::same_as<bool>;
36 // Require the ability to combine two scalar values
37 { lhs.join(rhs) } -> std::same_as<Val &>;
38 };
39
40template <typename Derived, ScalarLatticeValue ScalarTy> class AbstractLatticeValue {
48 using ArrayTy = std::vector<std::unique_ptr<Derived>>;
49
52 static ArrayTy constructArrayTy(const mlir::ArrayRef<int64_t> &shape) {
53 size_t totalElem = 1;
54 for (auto dim : shape) {
55 ensure(!mlir::ShapedType::isDynamic(dim), "Cannot pre-allocate dynamically-sized array");
56 totalElem *= dim;
57 }
58 ArrayTy arr(totalElem);
59 for (auto it = arr.begin(); it != arr.end(); it++) {
60 *it = std::make_unique<Derived>();
61 }
62 return arr;
63 }
64
65 static inline bool isDynamicArray(const mlir::ArrayRef<int64_t> &shape) {
66 return mlir::ShapedType::isDynamicShape(shape);
67 }
68
69public:
70 explicit AbstractLatticeValue(ScalarTy s)
71 : value(s), arrayShape(std::nullopt), isDynamic(false) {}
73 explicit AbstractLatticeValue(const mlir::ArrayRef<int64_t> shape)
74 : arrayShape(shape), isDynamic(isDynamicArray(shape)) {
75 if (isDynamic) {
76 value = ScalarTy();
77 } else {
78 value = constructArrayTy(shape);
79 }
80 }
81
82 AbstractLatticeValue(const AbstractLatticeValue &rhs) { *this = rhs; }
83
84 // Enable copying by duplicating unique_ptrs and copying the contained values.
86 copyArrayShape(rhs);
87 if (rhs.isScalar() || rhs.isDynamicArray()) {
88 getValue() = rhs.getScalarValue();
89 } else {
90 // create an empty array of the same size
91 getValue() = constructArrayTy(rhs.getArrayShape());
92 auto &lhsArr = getArrayValue();
93 auto &rhsArr = rhs.getArrayValue();
94 for (unsigned i = 0; i < lhsArr.size(); i++) {
95 // Recursive copy assignment of lattice values
96 *lhsArr[i] = *rhsArr[i];
97 }
98 }
99 return *this;
100 }
101
102 bool isScalar() const { return std::holds_alternative<ScalarTy>(value); }
103 bool isSingleValue() const { return isScalar() && getScalarValue().size() == 1; }
104 bool isArray() const { return std::holds_alternative<ArrayTy>(value); }
105 bool isDynamicArray() const { return isDynamic; }
106
107 const ScalarTy &getScalarValue() const {
108 ensure(isScalar(), "not a scalar value");
109 return std::get<ScalarTy>(value);
110 }
111
112 ScalarTy &getScalarValue() {
113 ensure(isScalar(), "not a scalar value");
114 return std::get<ScalarTy>(value);
115 }
116
117 const ArrayTy &getArrayValue() const {
118 ensure(isArray() && !isDynamicArray(), "not a static array value");
119 return std::get<ArrayTy>(value);
120 }
121
122 ArrayTy &getArrayValue() {
123 ensure(isArray() && !isDynamicArray(), "not a static array value");
124 return std::get<ArrayTy>(value);
125 }
126
128 const Derived &getElemFlatIdx(unsigned i) const {
129 ensure(isArray() && !isDynamicArray(), "not a static array value");
130 auto &arr = getArrayValue();
131 ensure(i < arr.size(), "index out of range");
132 return *arr.at(i);
133 }
134
135 Derived &getElemFlatIdx(unsigned i) {
136 ensure(isArray() && !isDynamicArray(), "not a static array value");
137 auto &arr = getArrayValue();
138 ensure(i < arr.size(), "index out of range");
139 return *arr.at(i);
140 }
141
142 size_t getArraySize() const { return getArrayValue().size(); }
143
144 size_t getNumArrayDims() const { return getArrayShape().size(); }
145
146 void print(mlir::raw_ostream &os) const {
147 if (isScalar() || isDynamicArray()) {
148 os << getScalarValue();
149 } else {
150 os << "[ ";
151 const auto &arr = getArrayValue();
152 for (auto it = arr.begin(); it != arr.end();) {
153 (*it)->print(os);
154 it++;
155 if (it != arr.end()) {
156 os << ", ";
157 } else {
158 os << ' ';
159 }
160 }
161 os << ']';
162 }
163 }
164
167 ScalarTy foldToScalar() const {
168 if (isScalar()) {
169 return getScalarValue();
170 }
171
172 ScalarTy res;
173 for (auto &val : getArrayValue()) {
174 auto rhs = val->foldToScalar();
175 res.join(rhs);
176 }
177 return res;
178 }
179
182 mlir::ChangeResult setValue(const AbstractLatticeValue &rhs) {
183 if (*this == rhs) {
184 return mlir::ChangeResult::NoChange;
185 }
186 *this = rhs;
187 return mlir::ChangeResult::Change;
188 }
189
191 mlir::ChangeResult update(const Derived &rhs) {
192 if (isScalar() && rhs.isScalar()) {
193 return updateScalar(rhs.getScalarValue());
194 } else if (isArray() && rhs.isArray() && getArraySize() == rhs.getArraySize()) {
195 return updateArray(rhs.getArrayValue());
196 } else {
197 return foldAndUpdate(rhs);
198 }
199 }
200
201 bool operator==(const AbstractLatticeValue &rhs) const {
202 if (isScalar() && rhs.isScalar()) {
203 return getScalarValue() == rhs.getScalarValue();
204 } else if (isArray() && rhs.isArray() && getArraySize() == rhs.getArraySize()) {
205 for (size_t i = 0; i < getArraySize(); i++) {
206 if (getElemFlatIdx(i) != rhs.getElemFlatIdx(i)) {
207 return false;
208 }
209 }
210 return true;
211 }
212 return false;
213 }
214
215protected:
216 std::variant<ScalarTy, ArrayTy> &getValue() { return value; }
217
218 const std::vector<int64_t> &getArrayShape() const {
219 ensure(arrayShape != std::nullopt, "not an array value");
220 return arrayShape.value();
221 }
222
223 int64_t getArrayDim(unsigned i) const {
224 const auto &arrShape = getArrayShape();
225 ensure(i < arrShape.size(), "dimension index out of bounds");
226 return arrShape.at(i);
227 }
228
230 arrayShape = rhs.arrayShape;
231 isDynamic = rhs.isDynamic;
232 }
233
235 mlir::ChangeResult updateScalar(const ScalarTy &rhs) {
236 auto lhs = getScalarValue();
237 lhs.join(rhs);
238 if (getScalarValue() == lhs) {
239 return mlir::ChangeResult::NoChange;
240 }
241 getScalarValue() = lhs;
242 return mlir::ChangeResult::Change;
243 }
244
246 mlir::ChangeResult updateArray(const ArrayTy &rhs) {
247 mlir::ChangeResult res = mlir::ChangeResult::NoChange;
248 auto &lhs = getArrayValue();
249 for (size_t i = 0; i < getArraySize(); i++) {
250 res |= lhs[i]->update(*rhs.at(i));
251 }
252 return res;
253 }
254
257 mlir::ChangeResult foldAndUpdate(const Derived &rhs) {
258 auto folded = foldToScalar();
259 auto rhsScalar = rhs.foldToScalar();
260 folded.join(rhsScalar);
261 if (isScalar() && getScalarValue() == folded) {
262 return mlir::ChangeResult::NoChange;
263 }
264 getValue() = folded;
265 return mlir::ChangeResult::Change;
266 }
267
268private:
269 std::variant<ScalarTy, ArrayTy> value;
270 std::optional<std::vector<int64_t>> arrayShape;
271 bool isDynamic;
272};
273
274template <typename Derived, ScalarLatticeValue ScalarTy>
275mlir::raw_ostream &
276operator<<(mlir::raw_ostream &os, const AbstractLatticeValue<Derived, ScalarTy> &v) {
277 v.print(os);
278 return os;
279}
280
281} // namespace llzk::dataflow
282
283#undef DEBUG_TYPE
bool operator==(const AbstractLatticeValue &rhs) const
AbstractLatticeValue(const AbstractLatticeValue &rhs)
mlir::ChangeResult updateArray(const ArrayTy &rhs)
Union this value with the given array.
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
Union this value with the given scalar.
const std::vector< int64_t > & getArrayShape() const
AbstractLatticeValue(const mlir::ArrayRef< int64_t > shape)
std::variant< ScalarTy, ArrayTy > & getValue()
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const Derived &rhs)
Folds the current value into a scalar and folds rhs to a scalar and updates the current value to the join of the two scalars.
void copyArrayShape(const AbstractLatticeValue &rhs)
const Derived & getElemFlatIdx(unsigned i) const
Directly index into the flattened array using a single index.
void print(mlir::raw_ostream &os) const
AbstractLatticeValue & operator=(const AbstractLatticeValue &rhs)
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const AbstractLatticeValue< Derived, ScalarTy > &v)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
bool isDynamic(IntegerAttr intAttr)