LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
IntervalAnalysis.h
Go to the documentation of this file.
1//===-- IntervalAnalysis.h --------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
25#include "llzk/Util/Compare.h"
26
27#include <mlir/IR/BuiltinOps.h>
28#include <mlir/Pass/AnalysisManager.h>
29#include <mlir/Support/LLVM.h>
30
31#include <llvm/ADT/MapVector.h>
32#include <llvm/Support/SMTAPI.h>
33
34#include <array>
35#include <mutex>
36
37namespace llzk {
38
39/* Field */
40
43class Field {
44public:
47 static const Field &getField(const char *fieldName);
48
49 Field() = delete;
50 Field(const Field &) = default;
51 Field(Field &&) = default;
52 Field &operator=(const Field &) = default;
53
55 llvm::APSInt prime() const { return primeMod; }
56
58 llvm::APSInt half() const { return halfPrime; }
59
61 inline llvm::APSInt felt(unsigned i) const { return reduce(i); }
62
64 inline llvm::APSInt zero() const { return felt(0); }
65
67 inline llvm::APSInt one() const { return felt(1); }
68
70 inline llvm::APSInt maxVal() const { return prime() - one(); }
71
73 llvm::APSInt reduce(llvm::APSInt i) const;
74 llvm::APSInt reduce(unsigned i) const;
75
76 inline unsigned bitWidth() const { return primeMod.getBitWidth(); }
77
79 llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const {
80 return solver->mkSymbol(name, solver->getBitvectorSort(bitWidth()));
81 }
82
83 friend bool operator==(const Field &lhs, const Field &rhs) {
84 return lhs.primeMod == rhs.primeMod;
85 }
86
87private:
88 Field(std::string_view primeStr);
89 Field(llvm::APSInt p, llvm::APSInt h) : primeMod(p), halfPrime(h) {}
90
91 llvm::APSInt primeMod, halfPrime;
92
93 static void initKnownFields(llvm::DenseMap<llvm::StringRef, Field> &knownFields);
94};
95
96/* UnreducedInterval */
97
98class Interval;
99
103public:
111 static size_t getMaxBitWidth(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
112 return std::max(
113 {lhs.a.getBitWidth(), lhs.b.getBitWidth(), rhs.a.getBitWidth(), rhs.b.getBitWidth()}
114 );
115 }
116
117 UnreducedInterval(llvm::APSInt x, llvm::APSInt y) : a(x), b(y) {}
118 UnreducedInterval(llvm::APInt x, llvm::APInt y) : a(x), b(y) {}
120 UnreducedInterval(uint64_t x, uint64_t y) : a(llvm::APInt(64, x)), b(llvm::APInt(64, y)) {}
121
122 /* Operations */
123
127 Interval reduce(const Field &field) const;
128
133
138
148
158
168
178
180 friend UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs);
181 friend UnreducedInterval operator-(const UnreducedInterval &lhs, const UnreducedInterval &rhs);
182 friend UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs);
183
184 /* Comparisons */
185
186 bool overlaps(const UnreducedInterval &rhs) const;
187
188 friend std::strong_ordering
189 operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs);
190
191 friend bool operator==(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
192 return std::is_eq(lhs <=> rhs);
193 };
194
195 /* Utility */
196 llvm::APSInt getLHS() const { return a; }
197 llvm::APSInt getRHS() const { return b; }
198
201 llvm::APSInt width() const;
202
204 bool isEmpty() const;
205
206 bool isNotEmpty() const { return !isEmpty(); }
207
208 void print(llvm::raw_ostream &os) const { os << "Unreduced:[ " << a << ", " << b << " ]"; }
209
210 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const UnreducedInterval &ui) {
211 ui.print(os);
212 return os;
213 }
214
215private:
216 llvm::APSInt a, b;
217};
218
219/* Interval */
220
290class Interval {
291public:
292 enum class Type { TypeA = 0, TypeB, TypeC, TypeF, Empty, Degenerate, Entire };
293 static constexpr std::array<std::string_view, 7> TypeNames = {"TypeA", "TypeB", "TypeC",
294 "TypeF", "Empty", "Degenerate",
295 "Entire"};
296
297 static std::string_view TypeName(Type t) { return TypeNames.at(static_cast<size_t>(t)); }
298
299 /* Static constructors for convenience */
300
301 static Interval Empty(const Field &f) { return Interval(Type::Empty, f); }
302
303 static Interval Degenerate(const Field &f, llvm::APSInt val) {
304 return Interval(Type::Degenerate, f, val, val);
305 }
306
307 static Interval Boolean(const Field &f) { return Interval::TypeA(f, f.zero(), f.one()); }
308
309 static Interval Entire(const Field &f) { return Interval(Type::Entire, f); }
310
311 static Interval TypeA(const Field &f, llvm::APSInt a, llvm::APSInt b) {
312 return Interval(Type::TypeA, f, a, b);
313 }
314
315 static Interval TypeB(const Field &f, llvm::APSInt a, llvm::APSInt b) {
316 return Interval(Type::TypeB, f, a, b);
317 }
318
319 static Interval TypeC(const Field &f, llvm::APSInt a, llvm::APSInt b) {
320 return Interval(Type::TypeC, f, a, b);
321 }
322
323 static Interval TypeF(const Field &f, llvm::APSInt a, llvm::APSInt b) {
324 return Interval(Type::TypeF, f, a, b);
325 }
326
330
333
337
341
342 template <std::pair<Type, Type>... Pairs>
343 static bool areOneOf(const Interval &a, const Interval &b) {
344 return ((a.ty == std::get<0>(Pairs) && b.ty == std::get<1>(Pairs)) || ...);
345 }
346
348 Interval join(const Interval &rhs) const;
349
351 Interval intersect(const Interval &rhs) const;
352
366 Interval difference(const Interval &other) const;
367
368 /* arithmetic ops */
369
370 Interval operator-() const;
371 friend Interval operator+(const Interval &lhs, const Interval &rhs);
372 friend Interval operator-(const Interval &lhs, const Interval &rhs);
373 friend Interval operator*(const Interval &lhs, const Interval &rhs);
374 friend Interval operator%(const Interval &lhs, const Interval &rhs);
376 friend mlir::FailureOr<Interval> operator/(const Interval &lhs, const Interval &rhs);
377
378 /* Checks and Comparisons */
379
380 inline bool isEmpty() const { return ty == Type::Empty; }
381 inline bool isNotEmpty() const { return !isEmpty(); }
382 inline bool isDegenerate() const { return ty == Type::Degenerate; }
383 inline bool isEntire() const { return ty == Type::Entire; }
384 inline bool isTypeA() const { return ty == Type::TypeA; }
385 inline bool isTypeB() const { return ty == Type::TypeB; }
386 inline bool isTypeC() const { return ty == Type::TypeC; }
387 inline bool isTypeF() const { return ty == Type::TypeF; }
388
389 template <Type... Types> bool is() const { return ((ty == Types) || ...); }
390
391 bool operator==(const Interval &rhs) const { return ty == rhs.ty && a == rhs.a && b == rhs.b; }
392
393 /* Getters */
394
395 const Field &getField() const { return field.get(); }
396
397 llvm::APSInt width() const { return llvm::APSInt((b - a).abs().zext(field.get().bitWidth())); }
398
399 llvm::APSInt lhs() const { return a; }
400 llvm::APSInt rhs() const { return b; }
401
402 /* Utility */
403 struct Hash {
404 unsigned operator()(const Interval &i) const {
405 return std::hash<const Field *> {}(&i.field.get()) ^ std::hash<Type> {}(i.ty) ^
406 llvm::hash_value(i.a) ^ llvm::hash_value(i.b);
407 }
408 };
409
410 void print(mlir::raw_ostream &os) const;
411
412 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const Interval &i) {
413 i.print(os);
414 return os;
415 }
416
417private:
418 Interval(Type t, const Field &f) : field(f), ty(t), a(f.zero()), b(f.zero()) {}
419 Interval(Type t, const Field &f, llvm::APSInt lhs, llvm::APSInt rhs)
420 : field(f), ty(t), a(lhs.extend(f.bitWidth())), b(rhs.extend(f.bitWidth())) {}
421
422 std::reference_wrapper<const Field> field;
423 Type ty;
424 llvm::APSInt a, b;
425};
426
427/* ExpressionValue */
428
432public:
433 /* Must be default initializable to be a ScalarLatticeValue. */
434 ExpressionValue() : i(), expr(nullptr) {}
435
436 explicit ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
437 : i(Interval::Entire(f)), expr(exprRef) {}
438
439 ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, llvm::APSInt singleVal)
440 : i(Interval::Degenerate(f, singleVal)), expr(exprRef) {}
441
442 ExpressionValue(llvm::SMTExprRef exprRef, Interval interval) : i(interval), expr(exprRef) {}
443
444 llvm::SMTExprRef getExpr() const { return expr; }
445
446 const Interval &getInterval() const { return i; }
447
448 const Field &getField() const { return i.getField(); }
449
453 ExpressionValue withInterval(const Interval &newInterval) const {
454 return ExpressionValue(expr, newInterval);
455 }
456
457 /* Required to be a ScalarLatticeValue. */
461 return *this;
462 }
463
464 bool operator==(const ExpressionValue &rhs) const;
465
472 friend ExpressionValue
473 intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
474
481 friend ExpressionValue
482 join(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
483
484 // arithmetic ops
485
486 friend ExpressionValue
487 add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
488
489 friend ExpressionValue
490 sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
491
492 friend ExpressionValue
493 mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
494
495 friend ExpressionValue
496 div(llvm::SMTSolverRef solver, felt::DivFeltOp op, const ExpressionValue &lhs,
497 const ExpressionValue &rhs);
498
499 friend ExpressionValue
500 mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
501
502 friend ExpressionValue
503 cmp(llvm::SMTSolverRef solver, boolean::CmpOp op, const ExpressionValue &lhs,
504 const ExpressionValue &rhs);
505
515 llvm::SMTSolverRef solver, mlir::Operation *op, const ExpressionValue &lhs,
516 const ExpressionValue &rhs
517 );
518
519 friend ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val);
520
521 friend ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val);
522
523 friend ExpressionValue
524 fallbackUnaryOp(llvm::SMTSolverRef solver, mlir::Operation *op, const ExpressionValue &val);
525
526 /* Utility */
527
528 void print(mlir::raw_ostream &os) const;
529
530 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ExpressionValue &e) {
531 e.print(os);
532 return os;
533 }
534
535 struct Hash {
536 unsigned operator()(const ExpressionValue &e) const {
537 return Interval::Hash {}(e.i) ^ llvm::hash_value(e.expr);
538 }
539 };
540
541private:
542 Interval i;
543 llvm::SMTExprRef expr;
544};
545
546/* IntervalAnalysisLatticeValue */
547
549 : public dataflow::AbstractLatticeValue<IntervalAnalysisLatticeValue, ExpressionValue> {
550public:
551 using AbstractLatticeValue::AbstractLatticeValue;
552};
553
554/* IntervalAnalysisLattice */
555
557
561public:
563 // Map mlir::Values to LatticeValues
564 using ValueMap = mlir::DenseMap<mlir::Value, LatticeValue>;
565 // Expression to interval map for convenience.
566 using ExpressionIntervals = mlir::DenseMap<llvm::SMTExprRef, Interval>;
567 // Tracks all constraints and assignments in insertion order
568 using ConstraintSet = llvm::SetVector<ExpressionValue>;
569
570 using AbstractDenseLattice::AbstractDenseLattice;
571
572 mlir::ChangeResult join(const AbstractDenseLattice &other) override;
573
574 mlir::ChangeResult meet(const AbstractDenseLattice &rhs) override {
575 llvm::report_fatal_error("IntervalDataFlowAnalysis::meet : unsupported");
576 return mlir::ChangeResult::NoChange;
577 }
578
579 void print(mlir::raw_ostream &os) const override;
580
581 mlir::FailureOr<LatticeValue> getValue(mlir::Value v) const;
582
583 mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e);
584
585 mlir::ChangeResult addSolverConstraint(ExpressionValue e);
586
587 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const IntervalAnalysisLattice &l) {
588 l.print(os);
589 return os;
590 }
591
592 const ConstraintSet &getConstraints() const { return constraints; }
593
594 mlir::FailureOr<Interval> findInterval(llvm::SMTExprRef expr) const;
595
596private:
597 ValueMap valMap;
598 ConstraintSet constraints;
599 ExpressionIntervals intervals;
600};
601
602/* IntervalDataFlowAnalysis */
603
605 : public dataflow::DenseForwardDataFlowAnalysis<IntervalAnalysisLattice> {
607 using Lattice = IntervalAnalysisLattice;
608 using LatticeValue = IntervalAnalysisLattice::LatticeValue;
609
610 // Map fields to their symbols
611 using SymbolMap = mlir::DenseMap<ConstrainRef, llvm::SMTExprRef>;
612
613public:
615 mlir::DataFlowSolver &solver, llvm::SMTSolverRef smt, const Field &f
616 )
617 : Base::DenseForwardDataFlowAnalysis(solver), dataflowSolver(solver), smtSolver(smt),
618 field(f) {}
619
621 mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before,
622 Lattice *after
623 ) override;
624
625 void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override;
626
631 llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r);
632
633private:
634 mlir::DataFlowSolver &dataflowSolver;
635 llvm::SMTSolverRef smtSolver;
636 SymbolMap refSymbols;
637 std::reference_wrapper<const Field> field;
638
639 void setToEntryState(Lattice *lattice) override {
640 // initial state should be empty, so do nothing here
641 }
642
643 llvm::SMTExprRef createFeltSymbol(const ConstrainRef &r) const;
644
645 llvm::SMTExprRef createFeltSymbol(mlir::Value val) const;
646
647 llvm::SMTExprRef createFeltSymbol(const char *name) const;
648
649 bool isConstOp(mlir::Operation *op) const {
650 return mlir::isa<
651 felt::FeltConstantOp, mlir::arith::ConstantIndexOp, mlir::arith::ConstantIntOp>(op);
652 }
653
654 llvm::APSInt getConst(mlir::Operation *op) const;
655
656 llvm::SMTExprRef createConstBitvectorExpr(llvm::APSInt v) const {
657 return smtSolver->mkBitvector(v, field.get().bitWidth());
658 }
659
660 llvm::SMTExprRef createConstBoolExpr(bool v) const {
661 return smtSolver->mkBitvector(mlir::APSInt((int)v), field.get().bitWidth());
662 }
663
664 bool isArithmeticOp(mlir::Operation *op) const {
665 return mlir::isa<
666 felt::AddFeltOp, felt::SubFeltOp, felt::MulFeltOp, felt::DivFeltOp, felt::ModFeltOp,
667 felt::NegFeltOp, felt::InvFeltOp, felt::AndFeltOp, felt::OrFeltOp, felt::XorFeltOp,
668 felt::NotFeltOp, felt::ShlFeltOp, felt::ShrFeltOp, boolean::CmpOp>(op);
669 }
670
671 ExpressionValue
672 performBinaryArithmetic(mlir::Operation *op, const LatticeValue &a, const LatticeValue &b);
673
674 ExpressionValue performUnaryArithmetic(mlir::Operation *op, const LatticeValue &a);
675
682 mlir::ChangeResult
683 applyInterval(mlir::Operation *originalOp, Lattice *after, mlir::Value val, Interval newInterval);
684
685 bool isBoolOp(mlir::Operation *op) const {
686 return mlir::isa<boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(
687 op
688 );
689 }
690
691 bool isConversionOp(mlir::Operation *op) const {
692 return mlir::isa<cast::IntToFeltOp, cast::FeltToIndexOp>(op);
693 }
694
695 bool isApplyMapOp(mlir::Operation *op) const { return mlir::isa<polymorphic::ApplyMapOp>(op); }
696
697 bool isAssertOp(mlir::Operation *op) const { return mlir::isa<boolean::AssertOp>(op); }
698
699 bool isReadOp(mlir::Operation *op) const {
700 return mlir::isa<component::FieldReadOp, polymorphic::ConstReadOp, array::ReadArrayOp>(op);
701 }
702
703 bool isWriteOp(mlir::Operation *op) const {
704 return mlir::isa<component::FieldWriteOp, array::WriteArrayOp, array::InsertArrayOp>(op);
705 }
706
707 bool isArrayLengthOp(mlir::Operation *op) const { return mlir::isa<array::ArrayLengthOp>(op); }
708
709 bool isEmitOp(mlir::Operation *op) const {
710 return mlir::isa<constrain::EmitEqualityOp, constrain::EmitContainmentOp>(op);
711 }
712
713 bool isCreateOp(mlir::Operation *op) const {
714 return mlir::isa<component::CreateStructOp, array::CreateArrayOp>(op);
715 }
716
717 bool isExtractArrayOp(mlir::Operation *op) const { return mlir::isa<array::ExtractArrayOp>(op); }
718
719 bool isDefinitionOp(mlir::Operation *op) const {
720 return mlir::isa<
721 component::StructDefOp, function::FuncDefOp, component::FieldDefOp, global::GlobalDefOp,
722 mlir::ModuleOp>(op);
723 }
724
725 bool isCallOp(mlir::Operation *op) const { return mlir::isa<function::CallOp>(op); }
726
727 bool isReturnOp(mlir::Operation *op) const { return mlir::isa<function::ReturnOp>(op); }
728
734 bool isConsideredOp(mlir::Operation *op) const {
735 return isConstOp(op) || isArithmeticOp(op) || isBoolOp(op) || isConversionOp(op) ||
736 isApplyMapOp(op) || isAssertOp(op) || isReadOp(op) || isWriteOp(op) ||
737 isArrayLengthOp(op) || isEmitOp(op) || isCreateOp(op) || isDefinitionOp(op) ||
738 isCallOp(op) || isReturnOp(op) || isExtractArrayOp(op);
739 }
740};
741
742/* StructIntervals */
743
747 llvm::SMTSolverRef smtSolver;
748 std::reference_wrapper<const Field> field;
749
750 llvm::SMTExprRef getSymbol(const ConstrainRef &r) { return intervalDFA->getOrCreateSymbol(r); }
751 const Field &getField() const { return field.get(); }
752};
753
754class StructIntervals {
755public:
765 static mlir::FailureOr<StructIntervals> compute(
766 mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver,
767 mlir::AnalysisManager &am, IntervalAnalysisContext &ctx
768 ) {
769 StructIntervals si(mod, s);
770 if (si.computeIntervals(solver, am, ctx).failed()) {
771 return mlir::failure();
772 }
773 return si;
774 }
775
776 mlir::LogicalResult computeIntervals(
777 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx
778 );
779
780 void print(mlir::raw_ostream &os, bool withConstraints = false) const;
781
782 const llvm::MapVector<ConstrainRef, Interval> &getIntervals() const {
783 return constrainFieldRanges;
784 }
785
786 const llvm::SetVector<ExpressionValue> getSolverConstraints() const {
787 return constrainSolverConstraints;
788 }
789
790 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const StructIntervals &si) {
791 si.print(os);
792 return os;
793 }
794
795private:
796 mlir::ModuleOp mod;
797 component::StructDefOp structDef;
798 llvm::SMTSolverRef smtSolver;
799 // llvm::MapVector keeps insertion order for consistent iteration
800 llvm::MapVector<ConstrainRef, Interval> constrainFieldRanges;
801 // llvm::SetVector for the same reasons as above
802 llvm::SetVector<ExpressionValue> constrainSolverConstraints;
803
804 StructIntervals(mlir::ModuleOp m, component::StructDefOp s) : mod(m), structDef(s) {}
805};
806
807/* StructIntervalAnalysis */
808
810
811class StructIntervalAnalysis : public StructAnalysis<StructIntervals, IntervalAnalysisContext> {
812public:
814
815 mlir::LogicalResult runAnalysis(
816 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager,
818 ) override {
819 auto res =
820 StructIntervals::compute(getModule(), getStruct(), solver, moduleAnalysisManager, ctx);
821 if (mlir::failed(res)) {
822 return mlir::failure();
823 }
824 setResult(std::move(*res));
825 return mlir::success();
826 }
827};
828
829/* ModuleIntervalAnalysis */
830
832 : public ModuleAnalysis<StructIntervals, IntervalAnalysisContext, StructIntervalAnalysis> {
833
834public:
835 ModuleIntervalAnalysis(mlir::Operation *op)
836 : ModuleAnalysis(op), smtSolver(llvm::CreateZ3Solver()), field(std::nullopt) {}
837
838 void setField(const Field &f) { field = f; }
839
840protected:
841 void initializeSolver(mlir::DataFlowSolver &solver) override {
842 ensure(field.has_value(), "field not set, could not generate analysis context");
843 (void)solver.load<ConstrainRefAnalysis>();
844 auto smtSolverRef = smtSolver;
845 intervalDFA = solver.load<IntervalDataFlowAnalysis, llvm::SMTSolverRef, const Field &>(
846 std::move(smtSolverRef), field.value()
847 );
848 }
849
851 ensure(field.has_value(), "field not set, could not generate analysis context");
852 return {
853 .intervalDFA = intervalDFA,
854 .smtSolver = smtSolver,
855 .field = field.value(),
856 };
857 }
858
859private:
860 llvm::SMTSolverRef smtSolver;
861 IntervalDataFlowAnalysis *intervalDFA;
862 std::optional<std::reference_wrapper<const Field>> field;
863};
864
865} // namespace llzk
866
867namespace llvm {
868
869template <> struct DenseMapInfo<llzk::ExpressionValue> {
870
871 static SMTExprRef getEmptyExpr() {
872 static auto emptyPtr = reinterpret_cast<SMTExprRef>(1);
873 return emptyPtr;
874 }
875 static SMTExprRef getTombstoneExpr() {
876 static auto tombstonePtr = reinterpret_cast<SMTExprRef>(2);
877 return tombstonePtr;
878 }
879
886 static unsigned getHashValue(const llzk::ExpressionValue &e) {
887 return llzk::ExpressionValue::Hash {}(e);
888 }
889 static bool isEqual(const llzk::ExpressionValue &lhs, const llzk::ExpressionValue &rhs) {
890 if (lhs.getExpr() == getEmptyExpr() || lhs.getExpr() == getTombstoneExpr() ||
891 rhs.getExpr() == getEmptyExpr() || rhs.getExpr() == getTombstoneExpr()) {
892 return lhs.getExpr() == rhs.getExpr();
893 }
894 return lhs == rhs;
895 }
896};
897
898} // namespace llvm
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
Convenience classes for a frequent pattern of dataflow analysis used in LLZK, where an analysis is ru...
This file implements (LLZK-tailored) dense data-flow analysis using the data-flow analysis framework.
The dataflow analysis that computes the set of references that LLZK operations use and produce.
Defines a reference to a llzk object within a constrain function call.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
friend ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
const Interval & getInterval() const
friend ExpressionValue cmp(llvm::SMTSolverRef solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
friend ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, mlir::Operation *op, const ExpressionValue &val)
friend ExpressionValue div(llvm::SMTSolverRef solver, felt::DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
friend ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
Computes a solver expression based on the operation, but computes a fallback interval (which is just ...
ExpressionValue & join(const ExpressionValue &rhs)
Fold two expressions together when overapproximating array elements.
ExpressionValue(llvm::SMTExprRef exprRef, Interval interval)
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, llvm::APSInt singleVal)
friend ExpressionValue join(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the union of the lhs and rhs intervals, and create a solver expression that constrains both s...
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ExpressionValue &e)
const Field & getField() const
friend ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
Information about the prime finite field used for the interval analysis.
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
llvm::APSInt felt(unsigned i) const
Returns i as a field element.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::APSInt zero() const
Returns 0 at the bitwidth of the field.
friend bool operator==(const Field &lhs, const Field &rhs)
Field()=delete
Field(const Field &)=default
Field(Field &&)=default
llvm::APSInt half() const
Returns p / 2.
unsigned bitWidth() const
Field & operator=(const Field &)=default
llvm::APSInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
static const Field & getField(const char *fieldName)
Get a Field from a given field name string.
llvm::APSInt prime() const
For the prime field p, returns p.
llvm::APSInt reduce(llvm::APSInt i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
Maps mlir::Values to LatticeValues.
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const IntervalAnalysisLattice &l)
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
llvm::SetVector< ExpressionValue > ConstraintSet
IntervalAnalysisLatticeValue LatticeValue
mlir::DenseMap< mlir::Value, LatticeValue > ValueMap
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::DenseMap< llvm::SMTExprRef, Interval > ExpressionIntervals
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
mlir::ChangeResult meet(const AbstractDenseLattice &rhs) override
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
IntervalDataFlowAnalysis(mlir::DataFlowSolver &solver, llvm::SMTSolverRef smt, const Field &f)
Intervals over a finite field.
bool isEmpty() const
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const Interval &i)
static constexpr std::array< std::string_view, 7 > TypeNames
bool isTypeA() const
Interval intersect(const Interval &rhs) const
Intersect.
static std::string_view TypeName(Type t)
llvm::APSInt rhs() const
bool isTypeC() const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
bool isTypeB() const
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
bool isDegenerate() const
void print(mlir::raw_ostream &os) const
const Field & getField() const
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
static Interval TypeC(const Field &f, llvm::APSInt a, llvm::APSInt b)
bool isNotEmpty() const
static Interval TypeF(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeB(const Field &f, llvm::APSInt a, llvm::APSInt b)
bool operator==(const Interval &rhs) const
bool isTypeF() const
llvm::APSInt width() const
friend Interval operator*(const Interval &lhs, const Interval &rhs)
llvm::APSInt lhs() const
static Interval Empty(const Field &f)
static bool areOneOf(const Interval &a, const Interval &b)
static Interval Degenerate(const Field &f, llvm::APSInt val)
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
friend mlir::FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Returns failure if a division-by-zero is encountered.
static Interval TypeA(const Field &f, llvm::APSInt a, llvm::APSInt b)
friend Interval operator+(const Interval &lhs, const Interval &rhs)
bool isEntire() const
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
friend Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
ModuleIntervalAnalysis(mlir::Operation *op)
IntervalAnalysisContext getContext() override
Create and return a valid Context object.
void initializeSolver(mlir::DataFlowSolver &solver) override
Initialize the shared dataflow solver with any common analyses required by the contained struct analy...
void setField(const Field &f)
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, IntervalAnalysisContext &ctx) override
Perform the analysis and construct the Result output.
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
const llvm::SetVector< ExpressionValue > getSolverConstraints() const
const llvm::MapVector< ConstrainRef, Interval > & getIntervals() const
void print(mlir::raw_ostream &os, bool withConstraints=false) const
static mlir::FailureOr< StructIntervals > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
Compute the struct intervals.
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const StructIntervals &si)
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
UnreducedInterval operator-() const
friend UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
UnreducedInterval(llvm::APInt x, llvm::APInt y)
bool isEmpty() const
Returns true iff width() is zero.
UnreducedInterval(llvm::APSInt x, llvm::APSInt y)
llvm::APSInt width() const
Compute the width of this interval.
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
friend std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval(uint64_t x, uint64_t y)
This constructor is primarily for convenience for unit tests.
bool overlaps(const UnreducedInterval &rhs) const
llvm::APSInt getRHS() const
friend llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const UnreducedInterval &ui)
llvm::APSInt getLHS() const
friend bool operator==(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
void print(llvm::raw_ostream &os) const
static size_t getMaxBitWidth(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
A utility method to determine the largest bitwidth among arms of two UnreducedIntervals.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
friend UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
LLZK: This class has been ported so that it can inherit from our port of the AbstractDenseForwardData...
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static unsigned getHashValue(const llzk::ExpressionValue &e)
static bool isEqual(const llzk::ExpressionValue &lhs, const llzk::ExpressionValue &rhs)
static llzk::ExpressionValue getTombstoneKey()
static llzk::ExpressionValue getEmptyKey()
unsigned operator()(const ExpressionValue &e) const
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA
llvm::SMTExprRef getSymbol(const ConstrainRef &r)
unsigned operator()(const Interval &i) const