27#include <mlir/IR/BuiltinOps.h>
28#include <mlir/Pass/AnalysisManager.h>
29#include <mlir/Support/LLVM.h>
31#include <llvm/ADT/MapVector.h>
32#include <llvm/Support/SMTAPI.h>
55 llvm::APSInt
prime()
const {
return primeMod; }
58 llvm::APSInt
half()
const {
return halfPrime; }
61 inline llvm::APSInt
felt(
unsigned i)
const {
return reduce(i); }
64 inline llvm::APSInt
zero()
const {
return felt(0); }
67 inline llvm::APSInt
one()
const {
return felt(1); }
73 llvm::APSInt
reduce(llvm::APSInt i)
const;
74 llvm::APSInt
reduce(
unsigned i)
const;
76 inline unsigned bitWidth()
const {
return primeMod.getBitWidth(); }
79 llvm::SMTExprRef
createSymbol(llvm::SMTSolverRef solver,
const char *name)
const {
80 return solver->mkSymbol(name, solver->getBitvectorSort(
bitWidth()));
84 return lhs.primeMod == rhs.primeMod;
88 Field(std::string_view primeStr);
89 Field(llvm::APSInt p, llvm::APSInt h) : primeMod(p), halfPrime(h) {}
91 llvm::APSInt primeMod, halfPrime;
93 static void initKnownFields(llvm::DenseMap<llvm::StringRef, Field> &knownFields);
113 {lhs.a.getBitWidth(), lhs.b.getBitWidth(), rhs.a.getBitWidth(), rhs.b.getBitWidth()}
188 friend std::strong_ordering
192 return std::is_eq(lhs <=> rhs);
196 llvm::APSInt
getLHS()
const {
return a; }
197 llvm::APSInt
getRHS()
const {
return b; }
201 llvm::APSInt
width()
const;
208 void print(llvm::raw_ostream &os)
const { os <<
"Unreduced:[ " << a <<
", " << b <<
" ]"; }
293 static constexpr std::array<std::string_view, 7>
TypeNames = {
"TypeA",
"TypeB",
"TypeC",
294 "TypeF",
"Empty",
"Degenerate",
342 template <std::pair<Type, Type>... Pairs>
344 return ((a.ty == std::get<0>(Pairs) && b.ty == std::get<1>(Pairs)) || ...);
389 template <
Type... Types>
bool is()
const {
return ((ty == Types) || ...); }
397 llvm::APSInt
width()
const {
return llvm::APSInt((b - a).abs().zext(field.get().bitWidth())); }
399 llvm::APSInt
lhs()
const {
return a; }
400 llvm::APSInt
rhs()
const {
return b; }
405 return std::hash<const Field *> {}(&i.field.get()) ^ std::hash<Type> {}(i.ty) ^
406 llvm::hash_value(i.a) ^ llvm::hash_value(i.b);
410 void print(mlir::raw_ostream &os)
const;
418 Interval(Type t,
const Field &f) : field(f), ty(t), a(f.zero()), b(f.zero()) {}
420 : field(f), ty(t), a(
lhs.extend(f.bitWidth())), b(
rhs.extend(f.bitWidth())) {}
422 std::reference_wrapper<const Field> field;
437 : i(
Interval::Entire(f)), expr(exprRef) {}
440 : i(
Interval::Degenerate(f, singleVal)), expr(exprRef) {}
444 llvm::SMTExprRef
getExpr()
const {
return expr; }
515 llvm::SMTSolverRef solver, mlir::Operation *op,
const ExpressionValue &lhs,
528 void print(mlir::raw_ostream &os)
const;
543 llvm::SMTExprRef expr;
551 using AbstractLatticeValue::AbstractLatticeValue;
564 using ValueMap = mlir::DenseMap<mlir::Value, LatticeValue>;
570 using AbstractDenseLattice::AbstractDenseLattice;
572 mlir::ChangeResult
join(
const AbstractDenseLattice &other)
override;
574 mlir::ChangeResult
meet(
const AbstractDenseLattice &rhs)
override {
575 llvm::report_fatal_error(
"IntervalDataFlowAnalysis::meet : unsupported");
576 return mlir::ChangeResult::NoChange;
579 void print(mlir::raw_ostream &os)
const override;
581 mlir::FailureOr<LatticeValue>
getValue(mlir::Value v)
const;
594 mlir::FailureOr<Interval>
findInterval(llvm::SMTExprRef expr)
const;
611 using SymbolMap = mlir::DenseMap<ConstrainRef, llvm::SMTExprRef>;
615 mlir::DataFlowSolver &solver, llvm::SMTSolverRef smt,
const Field &f
617 : Base::DenseForwardDataFlowAnalysis(solver), dataflowSolver(solver), smtSolver(smt),
625 void visitOperation(mlir::Operation *op,
const Lattice &before, Lattice *after)
override;
634 mlir::DataFlowSolver &dataflowSolver;
635 llvm::SMTSolverRef smtSolver;
636 SymbolMap refSymbols;
637 std::reference_wrapper<const Field> field;
639 void setToEntryState(Lattice *lattice)
override {
643 llvm::SMTExprRef createFeltSymbol(
const ConstrainRef &r)
const;
645 llvm::SMTExprRef createFeltSymbol(mlir::Value val)
const;
647 llvm::SMTExprRef createFeltSymbol(
const char *name)
const;
649 bool isConstOp(mlir::Operation *op)
const {
651 felt::FeltConstantOp, mlir::arith::ConstantIndexOp, mlir::arith::ConstantIntOp>(op);
654 llvm::APSInt getConst(mlir::Operation *op)
const;
656 llvm::SMTExprRef createConstBitvectorExpr(llvm::APSInt v)
const {
657 return smtSolver->mkBitvector(v, field.get().bitWidth());
660 llvm::SMTExprRef createConstBoolExpr(
bool v)
const {
661 return smtSolver->mkBitvector(mlir::APSInt((
int)v), field.get().bitWidth());
664 bool isArithmeticOp(mlir::Operation *op)
const {
666 felt::AddFeltOp, felt::SubFeltOp, felt::MulFeltOp, felt::DivFeltOp, felt::ModFeltOp,
667 felt::NegFeltOp, felt::InvFeltOp, felt::AndFeltOp, felt::OrFeltOp, felt::XorFeltOp,
668 felt::NotFeltOp, felt::ShlFeltOp, felt::ShrFeltOp, boolean::CmpOp>(op);
672 performBinaryArithmetic(mlir::Operation *op,
const LatticeValue &a,
const LatticeValue &b);
674 ExpressionValue performUnaryArithmetic(mlir::Operation *op,
const LatticeValue &a);
683 applyInterval(mlir::Operation *originalOp, Lattice *after, mlir::Value val, Interval newInterval);
685 bool isBoolOp(mlir::Operation *op)
const {
686 return mlir::isa<boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(
691 bool isConversionOp(mlir::Operation *op)
const {
692 return mlir::isa<cast::IntToFeltOp, cast::FeltToIndexOp>(op);
695 bool isApplyMapOp(mlir::Operation *op)
const {
return mlir::isa<polymorphic::ApplyMapOp>(op); }
697 bool isAssertOp(mlir::Operation *op)
const {
return mlir::isa<boolean::AssertOp>(op); }
699 bool isReadOp(mlir::Operation *op)
const {
700 return mlir::isa<component::FieldReadOp, polymorphic::ConstReadOp, array::ReadArrayOp>(op);
703 bool isWriteOp(mlir::Operation *op)
const {
704 return mlir::isa<component::FieldWriteOp, array::WriteArrayOp, array::InsertArrayOp>(op);
707 bool isArrayLengthOp(mlir::Operation *op)
const {
return mlir::isa<array::ArrayLengthOp>(op); }
709 bool isEmitOp(mlir::Operation *op)
const {
710 return mlir::isa<constrain::EmitEqualityOp, constrain::EmitContainmentOp>(op);
713 bool isCreateOp(mlir::Operation *op)
const {
714 return mlir::isa<component::CreateStructOp, array::CreateArrayOp>(op);
717 bool isExtractArrayOp(mlir::Operation *op)
const {
return mlir::isa<array::ExtractArrayOp>(op); }
719 bool isDefinitionOp(mlir::Operation *op)
const {
721 component::StructDefOp, function::FuncDefOp, component::FieldDefOp, global::GlobalDefOp,
725 bool isCallOp(mlir::Operation *op)
const {
return mlir::isa<function::CallOp>(op); }
727 bool isReturnOp(mlir::Operation *op)
const {
return mlir::isa<function::ReturnOp>(op); }
734 bool isConsideredOp(mlir::Operation *op)
const {
735 return isConstOp(op) || isArithmeticOp(op) || isBoolOp(op) || isConversionOp(op) ||
736 isApplyMapOp(op) || isAssertOp(op) || isReadOp(op) || isWriteOp(op) ||
737 isArrayLengthOp(op) || isEmitOp(op) || isCreateOp(op) || isDefinitionOp(op) ||
738 isCallOp(op) || isReturnOp(op) || isExtractArrayOp(op);
748 std::reference_wrapper<const Field>
field;
754class StructIntervals {
765 static mlir::FailureOr<StructIntervals>
compute(
769 StructIntervals si(mod, s);
771 return mlir::failure();
780 void print(mlir::raw_ostream &os,
bool withConstraints =
false)
const;
783 return constrainFieldRanges;
787 return constrainSolverConstraints;
790 friend mlir::raw_ostream &
operator<<(mlir::raw_ostream &os,
const StructIntervals &si) {
798 llvm::SMTSolverRef smtSolver;
800 llvm::MapVector<ConstrainRef, Interval> constrainFieldRanges;
802 llvm::SetVector<ExpressionValue> constrainSolverConstraints;
816 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager,
821 if (mlir::failed(res)) {
822 return mlir::failure();
825 return mlir::success();
832 :
public ModuleAnalysis<StructIntervals, IntervalAnalysisContext, StructIntervalAnalysis> {
842 ensure(field.has_value(),
"field not set, could not generate analysis context");
844 auto smtSolverRef = smtSolver;
846 std::move(smtSolverRef), field.value()
851 ensure(field.has_value(),
"field not set, could not generate analysis context");
853 .intervalDFA = intervalDFA,
854 .smtSolver = smtSolver,
855 .field = field.value(),
860 llvm::SMTSolverRef smtSolver;
862 std::optional<std::reference_wrapper<const Field>> field;
869template <>
struct DenseMapInfo<
llzk::ExpressionValue> {
872 static auto emptyPtr =
reinterpret_cast<SMTExprRef
>(1);
876 static auto tombstonePtr =
reinterpret_cast<SMTExprRef
>(2);
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
Convenience classes for a frequent pattern of dataflow analysis used in LLZK, where an analysis is ru...
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
The dataflow analysis that computes the set of references that LLZK operations use and produce.
Defines a reference to a llzk object within a constrain function call.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
friend ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
const Interval & getInterval() const
friend ExpressionValue cmp(llvm::SMTSolverRef solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
friend ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, mlir::Operation *op, const ExpressionValue &val)
friend ExpressionValue div(llvm::SMTSolverRef solver, felt::DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
friend ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
Computes a solver expression based on the operation, but computes a fallback interval (which is just ...
ExpressionValue & join(const ExpressionValue &rhs)
Fold two expressions together when overapproximating array elements.
ExpressionValue(llvm::SMTExprRef exprRef, Interval interval)
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, llvm::APSInt singleVal)
friend ExpressionValue join(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the union of the lhs and rhs intervals, and create a solver expression that constrains both s...
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ExpressionValue &e)
const Field & getField() const
friend ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
Information about the prime finite field used for the interval analysis.
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
llvm::APSInt felt(unsigned i) const
Returns i as a field element.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::APSInt zero() const
Returns 0 at the bitwidth of the field.
friend bool operator==(const Field &lhs, const Field &rhs)
Field(const Field &)=default
llvm::APSInt half() const
Returns p / 2.
unsigned bitWidth() const
Field & operator=(const Field &)=default
llvm::APSInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
static const Field & getField(const char *fieldName)
Get a Field from a given field name string.
llvm::APSInt prime() const
For the prime field p, returns p.
llvm::APSInt reduce(llvm::APSInt i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
Maps mlir::Values to LatticeValues.
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const IntervalAnalysisLattice &l)
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
llvm::SetVector< ExpressionValue > ConstraintSet
IntervalAnalysisLatticeValue LatticeValue
mlir::DenseMap< mlir::Value, LatticeValue > ValueMap
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::DenseMap< llvm::SMTExprRef, Interval > ExpressionIntervals
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
mlir::ChangeResult meet(const AbstractDenseLattice &rhs) override
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
IntervalDataFlowAnalysis(mlir::DataFlowSolver &solver, llvm::SMTSolverRef smt, const Field &f)
Intervals over a finite field.
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const Interval &i)
static constexpr std::array< std::string_view, 7 > TypeNames
Interval intersect(const Interval &rhs) const
Intersect.
static std::string_view TypeName(Type t)
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
bool isDegenerate() const
void print(mlir::raw_ostream &os) const
const Field & getField() const
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
static Interval TypeC(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeF(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeB(const Field &f, llvm::APSInt a, llvm::APSInt b)
bool operator==(const Interval &rhs) const
llvm::APSInt width() const
friend Interval operator*(const Interval &lhs, const Interval &rhs)
static Interval Empty(const Field &f)
static bool areOneOf(const Interval &a, const Interval &b)
static Interval Degenerate(const Field &f, llvm::APSInt val)
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
friend mlir::FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Returns failure if a division-by-zero is encountered.
static Interval TypeA(const Field &f, llvm::APSInt a, llvm::APSInt b)
friend Interval operator+(const Interval &lhs, const Interval &rhs)
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
friend Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
ModuleAnalysis(mlir::Operation *op)
ModuleIntervalAnalysis(mlir::Operation *op)
IntervalAnalysisContext getContext() override
Create and return a valid Context object.
void initializeSolver(mlir::DataFlowSolver &solver) override
Initialize the shared dataflow solver with any common analyses required by the contained struct analy...
void setField(const Field &f)
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
component::StructDefOp getStruct() const
mlir::ModuleOp getModule() const
void setResult(StructIntervals &&r)
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, IntervalAnalysisContext &ctx) override
Perform the analysis and construct the Result output.
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
const llvm::SetVector< ExpressionValue > getSolverConstraints() const
const llvm::MapVector< ConstrainRef, Interval > & getIntervals() const
void print(mlir::raw_ostream &os, bool withConstraints=false) const
static mlir::FailureOr< StructIntervals > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
Compute the struct intervals.
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const StructIntervals &si)
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
UnreducedInterval operator-() const
friend UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
UnreducedInterval(llvm::APInt x, llvm::APInt y)
bool isEmpty() const
Returns true iff width() is zero.
UnreducedInterval(llvm::APSInt x, llvm::APSInt y)
llvm::APSInt width() const
Compute the width of this interval within a given field f.
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
friend std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval(uint64_t x, uint64_t y)
This constructor is primarily for convenience for unit tests.
bool overlaps(const UnreducedInterval &rhs) const
llvm::APSInt getRHS() const
friend llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const UnreducedInterval &ui)
llvm::APSInt getLHS() const
friend bool operator==(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
void print(llvm::raw_ostream &os) const
static size_t getMaxBitWidth(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
A utility method to determine the largest bitwidth among arms of two UnreducedIntervals.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
friend UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
LLZK: This class has been ported so that it can inherit from our port of the AbstractDenseForwardData...
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
void ensure(bool condition, llvm::Twine errMsg)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static unsigned getHashValue(const llzk::ExpressionValue &e)
static SMTExprRef getTombstoneExpr()
static bool isEqual(const llzk::ExpressionValue &lhs, const llzk::ExpressionValue &rhs)
static llzk::ExpressionValue getTombstoneKey()
static llzk::ExpressionValue getEmptyKey()
static SMTExprRef getEmptyExpr()
unsigned operator()(const ExpressionValue &e) const
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA
llvm::SMTExprRef getSymbol(const ConstrainRef &r)
llvm::SMTSolverRef smtSolver
unsigned operator()(const Interval &i) const