LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
AbstractLatticeValue.h
Go to the documentation of this file.
1//===-- AbstractLatticeValue.h ----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
12#include "llzk/Util/Debug.h"
14
15#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
16#include <mlir/Support/LLVM.h>
17
18#include <llvm/Support/Debug.h>
19
20#include <concepts>
21#include <type_traits>
22#include <variant>
23
24#define DEBUG_TYPE "llzk-abstract-lattice-value"
25
26namespace llzk::dataflow {
27
28template <typename Val>
29concept ScalarLatticeValue = requires(Val lhs, Val rhs, mlir::raw_ostream &os) {
30 // Require a form of print function
31 { os << lhs } -> std::same_as<mlir::raw_ostream &>;
32 // Require comparability
33 { lhs == rhs } -> std::same_as<bool>;
34 // Require the ability to combine two scalar values
35 { lhs.join(rhs) } -> std::same_as<Val &>;
36 // Require default constructable
37 requires std::default_initializable<Val>;
38};
39
40template <typename Derived, ScalarLatticeValue ScalarTy> class AbstractLatticeValue {
48 using ArrayTy = std::vector<std::unique_ptr<Derived>>;
49
52 static ArrayTy constructArrayTy(const mlir::ArrayRef<int64_t> &shape) {
53 size_t totalElem = 1;
54 for (auto dim : shape) {
55 totalElem *= dim;
56 }
57 ArrayTy arr(totalElem);
58 for (auto it = arr.begin(); it != arr.end(); it++) {
59 *it = std::make_unique<Derived>();
60 }
61 return arr;
62 }
63
64public:
65 explicit AbstractLatticeValue(ScalarTy s) : value(s), arrayShape(std::nullopt) {}
67 explicit AbstractLatticeValue(const mlir::ArrayRef<int64_t> shape)
68 : value(constructArrayTy(shape)), arrayShape(shape) {}
69
70 AbstractLatticeValue(const AbstractLatticeValue &rhs) { *this = rhs; }
71
72 // Enable copying by duplicating unique_ptrs and copying the contained values.
74 copyArrayShape(rhs);
75 if (rhs.isScalar()) {
76 getValue() = rhs.getScalarValue();
77 } else {
78 // create an empty array of the same size
79 getValue() = constructArrayTy(rhs.getArrayShape());
80 auto &lhsArr = getArrayValue();
81 auto &rhsArr = rhs.getArrayValue();
82 for (unsigned i = 0; i < lhsArr.size(); i++) {
83 // Recursive copy assignment of lattice values
84 *lhsArr[i] = *rhsArr[i];
85 }
86 }
87 return *this;
88 }
89
90 bool isScalar() const { return std::holds_alternative<ScalarTy>(value); }
91 bool isSingleValue() const { return isScalar() && getScalarValue().size() == 1; }
92 bool isArray() const { return std::holds_alternative<ArrayTy>(value); }
93
94 const ScalarTy &getScalarValue() const {
95 ensure(isScalar(), "not a scalar value");
96 return std::get<ScalarTy>(value);
97 }
98
99 ScalarTy &getScalarValue() {
100 ensure(isScalar(), "not a scalar value");
101 return std::get<ScalarTy>(value);
102 }
103
104 const ArrayTy &getArrayValue() const {
105 ensure(isArray(), "not an array value");
106 return std::get<ArrayTy>(value);
107 }
108
109 ArrayTy &getArrayValue() {
110 ensure(isArray(), "not an array value");
111 return std::get<ArrayTy>(value);
112 }
113
115 const Derived &getElemFlatIdx(unsigned i) const {
116 ensure(isArray(), "not an array value");
117 auto &arr = getArrayValue();
118 ensure(i < arr.size(), "index out of range");
119 return *arr.at(i);
120 }
121
122 Derived &getElemFlatIdx(unsigned i) {
123 ensure(isArray(), "not an array value");
124 auto &arr = getArrayValue();
125 ensure(i < arr.size(), "index out of range");
126 return *arr.at(i);
127 }
128
129 size_t getArraySize() const { return getArrayValue().size(); }
130
131 size_t getNumArrayDims() const { return getArrayShape().size(); }
132
133 void print(mlir::raw_ostream &os) const {
134 if (isScalar()) {
135 os << getScalarValue();
136 } else {
137 os << "[ ";
138 const auto &arr = getArrayValue();
139 for (auto it = arr.begin(); it != arr.end();) {
140 (*it)->print(os);
141 it++;
142 if (it != arr.end()) {
143 os << ", ";
144 } else {
145 os << ' ';
146 }
147 }
148 os << ']';
149 }
150 }
151
154 ScalarTy foldToScalar() const {
155 if (isScalar()) {
156 return getScalarValue();
157 }
158
159 ScalarTy res;
160 for (auto &val : getArrayValue()) {
161 auto rhs = val->foldToScalar();
162 res.join(rhs);
163 }
164 return res;
165 }
166
169 mlir::ChangeResult setValue(const AbstractLatticeValue &rhs) {
170 if (*this == rhs) {
171 return mlir::ChangeResult::NoChange;
172 }
173 *this = rhs;
174 return mlir::ChangeResult::Change;
175 }
176
178 mlir::ChangeResult update(const Derived &rhs) {
179 if (isScalar() && rhs.isScalar()) {
180 return updateScalar(rhs.getScalarValue());
181 } else if (isArray() && rhs.isArray() && getArraySize() == rhs.getArraySize()) {
182 return updateArray(rhs.getArrayValue());
183 } else {
184 return foldAndUpdate(rhs);
185 }
186 }
187
188 bool operator==(const AbstractLatticeValue &rhs) const {
189 if (isScalar() && rhs.isScalar()) {
190 return getScalarValue() == rhs.getScalarValue();
191 } else if (isArray() && rhs.isArray() && getArraySize() == rhs.getArraySize()) {
192 for (size_t i = 0; i < getArraySize(); i++) {
193 if (getElemFlatIdx(i) != rhs.getElemFlatIdx(i)) {
194 return false;
195 }
196 }
197 return true;
198 }
199 return false;
200 }
201
202protected:
203 std::variant<ScalarTy, ArrayTy> &getValue() { return value; }
204
205 const std::vector<int64_t> &getArrayShape() const {
206 ensure(arrayShape != std::nullopt, "not an array value");
207 return arrayShape.value();
208 }
209
210 int64_t getArrayDim(unsigned i) const {
211 const auto &arrShape = getArrayShape();
212 ensure(i < arrShape.size(), "dimension index out of bounds");
213 return arrShape.at(i);
214 }
215
216 void copyArrayShape(const AbstractLatticeValue &rhs) { arrayShape = rhs.arrayShape; }
217
219 mlir::ChangeResult updateScalar(const ScalarTy &rhs) {
220 auto lhs = getScalarValue();
221 lhs.join(rhs);
222 if (getScalarValue() == lhs) {
223 return mlir::ChangeResult::NoChange;
224 }
225 getScalarValue() = lhs;
226 return mlir::ChangeResult::Change;
227 }
228
230 mlir::ChangeResult updateArray(const ArrayTy &rhs) {
231 mlir::ChangeResult res = mlir::ChangeResult::NoChange;
232 auto &lhs = getArrayValue();
233 for (size_t i = 0; i < getArraySize(); i++) {
234 res |= lhs[i]->update(*rhs.at(i));
235 }
236 return res;
237 }
238
241 mlir::ChangeResult foldAndUpdate(const Derived &rhs) {
242 auto folded = foldToScalar();
243 auto rhsScalar = rhs.foldToScalar();
244 folded.join(rhsScalar);
245 if (isScalar() && getScalarValue() == folded) {
246 return mlir::ChangeResult::NoChange;
247 }
248 getValue() = folded;
249 return mlir::ChangeResult::Change;
250 }
251
252private:
253 std::variant<ScalarTy, ArrayTy> value;
254 std::optional<std::vector<int64_t>> arrayShape;
255};
256
257template <typename Derived, ScalarLatticeValue ScalarTy>
258mlir::raw_ostream &
259operator<<(mlir::raw_ostream &os, const AbstractLatticeValue<Derived, ScalarTy> &v) {
260 v.print(os);
261 return os;
262}
263
264} // namespace llzk::dataflow
265
266#undef DEBUG_TYPE
bool operator==(const AbstractLatticeValue &rhs) const
AbstractLatticeValue(const AbstractLatticeValue &rhs)
mlir::ChangeResult updateArray(const ArrayTy &rhs)
Union this value with the given array.
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
Union this value with the given scalar.
const std::vector< int64_t > & getArrayShape() const
AbstractLatticeValue(const mlir::ArrayRef< int64_t > shape)
std::variant< ScalarTy, ArrayTy > & getValue()
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const Derived &rhs)
Folds the current value and rhs into scalars, joins them, and updates the current value to the resulting scalar.
void copyArrayShape(const AbstractLatticeValue &rhs)
const Derived & getElemFlatIdx(unsigned i) const
Directly index into the flattened array using a single index.
void print(mlir::raw_ostream &os) const
AbstractLatticeValue & operator=(const AbstractLatticeValue &rhs)
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const AbstractLatticeValue< Derived, ScalarTy > &v)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32