15#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
16#include <mlir/Support/LLVM.h>
18#include <llvm/Support/Debug.h>
24#define DEBUG_TYPE "llzk-abstract-lattice-value"
28template <
typename Val>
31 { os << lhs } -> std::same_as<mlir::raw_ostream &>;
33 { lhs == rhs } -> std::same_as<bool>;
35 { lhs.join(rhs) } -> std::same_as<Val &>;
37 requires std::default_initializable<Val>;
48 using ArrayTy = std::vector<std::unique_ptr<Derived>>;
52 static ArrayTy constructArrayTy(
const mlir::ArrayRef<int64_t> &shape) {
54 for (
auto dim : shape) {
57 ArrayTy arr(totalElem);
58 for (
auto it = arr.begin(); it != arr.end(); it++) {
59 *it = std::make_unique<Derived>();
68 : value(constructArrayTy(shape)), arrayShape(shape) {}
82 for (
unsigned i = 0; i < lhsArr.size(); i++) {
84 *lhsArr[i] = *rhsArr[i];
90 bool isScalar()
const {
return std::holds_alternative<ScalarTy>(value); }
92 bool isArray()
const {
return std::holds_alternative<ArrayTy>(value); }
96 return std::get<ScalarTy>(value);
101 return std::get<ScalarTy>(value);
106 return std::get<ArrayTy>(value);
111 return std::get<ArrayTy>(value);
118 ensure(i < arr.size(),
"index out of range");
125 ensure(i < arr.size(),
"index out of range");
133 void print(mlir::raw_ostream &os)
const {
139 for (
auto it = arr.begin(); it != arr.end();) {
142 if (it != arr.end()) {
161 auto rhs = val->foldToScalar();
171 return mlir::ChangeResult::NoChange;
174 return mlir::ChangeResult::Change;
178 mlir::ChangeResult
update(
const Derived &rhs) {
203 std::variant<ScalarTy, ArrayTy> &
getValue() {
return value; }
206 ensure(arrayShape != std::nullopt,
"not an array value");
207 return arrayShape.value();
212 ensure(i < arrShape.size(),
"dimension index out of bounds");
213 return arrShape.at(i);
223 return mlir::ChangeResult::NoChange;
226 return mlir::ChangeResult::Change;
231 mlir::ChangeResult res = mlir::ChangeResult::NoChange;
234 res |= lhs[i]->update(*rhs.at(i));
243 auto rhsScalar = rhs.foldToScalar();
244 folded.join(rhsScalar);
246 return mlir::ChangeResult::NoChange;
249 return mlir::ChangeResult::Change;
253 std::variant<ScalarTy, ArrayTy> value;
254 std::optional<std::vector<int64_t>> arrayShape;
257template <
typename Derived, ScalarLatticeValue ScalarTy>
bool operator==(const AbstractLatticeValue &rhs) const
AbstractLatticeValue(const AbstractLatticeValue &rhs)
size_t getNumArrayDims() const
mlir::ChangeResult updateArray(const ArrayTy &rhs)
Union this value with the given array.
size_t getArraySize() const
Derived & getElemFlatIdx(unsigned i)
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
Union this value with the given scalar.
const std::vector< int64_t > & getArrayShape() const
int64_t getArrayDim(unsigned i) const
AbstractLatticeValue(const mlir::ArrayRef< int64_t > shape)
std::variant< ScalarTy, ArrayTy > & getValue()
AbstractLatticeValue(ScalarTy s)
ScalarTy & getScalarValue()
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const Derived &rhs)
Folds the current value into a scalar, folds rhs into a scalar, and updates the current value to the join of the two scalars.
void copyArrayShape(const AbstractLatticeValue &rhs)
bool isSingleValue() const
const Derived & getElemFlatIdx(unsigned i) const
Directly index into the flattened array using a single index.
const ScalarTy & getScalarValue() const
const ArrayTy & getArrayValue() const
ArrayTy & getArrayValue()
void print(mlir::raw_ostream &os) const
AbstractLatticeValue & operator=(const AbstractLatticeValue &rhs)
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const AbstractLatticeValue< Derived, ScalarTy > &v)
void ensure(bool condition, llvm::Twine errMsg)