15#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
16#include <mlir/Support/LLVM.h>
18#include <llvm/Support/Debug.h>
24#define DEBUG_TYPE "llzk-abstract-lattice-value"
28template <
typename Val>
31 std::default_initializable<Val> &&
requires(Val lhs, Val rhs, mlir::raw_ostream &os) {
33 { os << lhs } -> std::same_as<mlir::raw_ostream &>;
35 { lhs == rhs } -> std::same_as<bool>;
37 { lhs.join(rhs) } -> std::same_as<Val &>;
48 using ArrayTy = std::vector<std::unique_ptr<Derived>>;
52 static ArrayTy constructArrayTy(
const mlir::ArrayRef<int64_t> &shape) {
54 for (
auto dim : shape) {
55 ensure(!mlir::ShapedType::isDynamic(dim),
"Cannot pre-allocate dynamically-sized array");
58 ArrayTy arr(totalElem);
59 for (
auto it = arr.begin(); it != arr.end(); it++) {
60 *it = std::make_unique<Derived>();
65 static inline bool isDynamicArray(
const mlir::ArrayRef<int64_t> &shape) {
66 return mlir::ShapedType::isDynamicShape(shape);
71 : value(s), arrayShape(std::nullopt), isDynamic(false) {}
74 : arrayShape(shape), isDynamic(isDynamicArray(shape)) {
78 value = constructArrayTy(shape);
87 if (rhs.
isScalar() || rhs.isDynamicArray()) {
94 for (
unsigned i = 0; i < lhsArr.size(); i++) {
96 *lhsArr[i] = *rhsArr[i];
102 bool isScalar()
const {
return std::holds_alternative<ScalarTy>(value); }
104 bool isArray()
const {
return std::holds_alternative<ArrayTy>(value); }
109 return std::get<ScalarTy>(value);
114 return std::get<ScalarTy>(value);
119 return std::get<ArrayTy>(value);
124 return std::get<ArrayTy>(value);
131 ensure(i < arr.size(),
"index out of range");
138 ensure(i < arr.size(),
"index out of range");
146 void print(mlir::raw_ostream &os)
const {
152 for (
auto it = arr.begin(); it != arr.end();) {
155 if (it != arr.end()) {
174 auto rhs = val->foldToScalar();
184 return mlir::ChangeResult::NoChange;
187 return mlir::ChangeResult::Change;
191 mlir::ChangeResult
update(
const Derived &rhs) {
216 std::variant<ScalarTy, ArrayTy> &
getValue() {
return value; }
219 ensure(arrayShape != std::nullopt,
"not an array value");
220 return arrayShape.value();
225 ensure(i < arrShape.size(),
"dimension index out of bounds");
226 return arrShape.at(i);
230 arrayShape = rhs.arrayShape;
231 isDynamic = rhs.isDynamic;
239 return mlir::ChangeResult::NoChange;
242 return mlir::ChangeResult::Change;
247 mlir::ChangeResult res = mlir::ChangeResult::NoChange;
250 res |= lhs[i]->update(*rhs.at(i));
259 auto rhsScalar = rhs.foldToScalar();
260 folded.join(rhsScalar);
262 return mlir::ChangeResult::NoChange;
265 return mlir::ChangeResult::Change;
269 std::variant<ScalarTy, ArrayTy> value;
270 std::optional<std::vector<int64_t>> arrayShape;
274template <
typename Derived, ScalarLatticeValue ScalarTy>
bool operator==(const AbstractLatticeValue &rhs) const
AbstractLatticeValue(const AbstractLatticeValue &rhs)
size_t getNumArrayDims() const
mlir::ChangeResult updateArray(const ArrayTy &rhs)
Union this value with the given array.
size_t getArraySize() const
Derived & getElemFlatIdx(unsigned i)
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
Union this value with the given scalar.
const std::vector< int64_t > & getArrayShape() const
int64_t getArrayDim(unsigned i) const
AbstractLatticeValue(const mlir::ArrayRef< int64_t > shape)
std::variant< ScalarTy, ArrayTy > & getValue()
AbstractLatticeValue(ScalarTy s)
ScalarTy & getScalarValue()
bool isDynamicArray() const
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const Derived &rhs)
Folds the current value into a scalar, folds rhs into a scalar, and updates the current value to the join of the two, reporting whether the value changed.
void copyArrayShape(const AbstractLatticeValue &rhs)
bool isSingleValue() const
const Derived & getElemFlatIdx(unsigned i) const
Directly index into the flattened array using a single index.
const ScalarTy & getScalarValue() const
const ArrayTy & getArrayValue() const
ArrayTy & getArrayValue()
void print(mlir::raw_ostream &os) const
AbstractLatticeValue & operator=(const AbstractLatticeValue &rhs)
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const AbstractLatticeValue< Derived, ScalarTy > &v)
void ensure(bool condition, llvm::Twine errMsg)
bool isDynamic(IntegerAttr intAttr)