15#include <llvm/ADT/TypeSwitch.h>
29Field::Field(std::string_view primeStr) : primeMod(llvm::APSInt(primeStr)) {
30 halfPrime = (primeMod + felt(1)) / felt(2);
34 static llvm::DenseMap<llvm::StringRef, Field> knownFields;
35 static std::once_flag fieldsInit;
36 std::call_once(fieldsInit, initKnownFields, knownFields);
38 if (
auto it = knownFields.find(fieldName); it != knownFields.end()) {
41 llvm::report_fatal_error(
"field \"" + mlir::Twine(fieldName) +
"\" is unsupported");
44void Field::initKnownFields(llvm::DenseMap<llvm::StringRef, Field> &knownFields) {
46 knownFields.try_emplace(
48 Field(
"21888242871839275222246405745257275088696311157297823662689037894645226208583")
50 knownFields.try_emplace(
"bn254", knownFields.at(
"bn128"));
52 knownFields.try_emplace(
"babybear",
Field(
"2013265921"));
54 knownFields.try_emplace(
"goldilocks",
Field(
"18446744069414584321"));
56 knownFields.try_emplace(
"mersenne31",
Field(
"2147483647"));
60 unsigned maxBits = std::max(i.getBitWidth(),
bitWidth());
61 llvm::APSInt m = (i.extend(maxBits) %
prime().extend(maxBits)).trunc(
bitWidth());
69 auto ap = llvm::APSInt(llvm::APInt(
bitWidth(), i));
85 if ((rhs - lhs).isZero()) {
89 const auto &half = field.
half();
91 if (lhs.ult(half) && rhs.ult(half)) {
93 }
else if (lhs.ult(half)) {
99 if (lhs.uge(half) && rhs.ult(half)) {
121 auto one = llvm::APSInt(llvm::APInt(a.getBitWidth(), 1));
137 auto one = llvm::APSInt(llvm::APInt(a.getBitWidth(), 1));
166 auto minVal =
safeMin({v1, v2, v3, v4});
167 auto maxVal =
safeMax({v1, v2, v3, v4});
178 return std::strong_ordering::less;
181 return std::strong_ordering::greater;
183 return std::strong_ordering::equal;
190 w = llvm::APSInt::getUnsigned(0);
195 ensure(
safeGe(w, llvm::APSInt::getUnsigned(0)),
"cannot have negative width");
228 lhs.getField() ==
rhs.getField(),
"interval operations across differing fields is unsupported"
233 if (
lhs.isEntire() ||
rhs.isEntire()) {
242 if (
lhs.isDegenerate() ||
rhs.isDegenerate()) {
243 return lhs.toUnreduced().doUnion(
rhs.toUnreduced()).reduce(f);
253 auto lhsUnred =
lhs.firstUnreduced();
254 auto opt1 =
rhs.firstUnreduced().doUnion(lhsUnred);
255 auto opt2 =
rhs.secondUnreduced().doUnion(lhsUnred);
256 if (opt1.width() <= opt2.width()) {
257 return opt1.reduce(f);
259 return opt2.reduce(f);
262 return lhs.firstUnreduced().doUnion(
rhs.firstUnreduced()).reduce(f);
265 return lhs.secondUnreduced().doUnion(
rhs.firstUnreduced()).reduce(f);
277 llvm::report_fatal_error(
"unhandled join case");
284 lhs.getField() ==
rhs.getField(),
"interval operations across differing fields is unsupported"
288 if (
lhs.isEmpty() ||
rhs.isEmpty()) {
291 if (
lhs.isEntire()) {
294 if (
rhs.isEntire()) {
297 if (
lhs.isDegenerate() ||
rhs.isDegenerate()) {
298 return lhs.toUnreduced().intersect(
rhs.toUnreduced()).reduce(field.get());
305 auto maxA = std::max(
lhs.a,
rhs.a);
306 auto minB = std::min(
lhs.b,
rhs.b);
317 return lhs.firstUnreduced().intersect(
rhs.firstUnreduced()).reduce(field.get());
320 return lhs.secondUnreduced().intersect(
rhs.firstUnreduced()).reduce(field.get());
323 auto rhsUnred =
rhs.firstUnreduced();
324 auto opt1 =
lhs.firstUnreduced().intersect(rhsUnred).reduce(field.get());
325 auto opt2 =
lhs.secondUnreduced().intersect(rhsUnred).reduce(field.get());
326 ensure(!opt1.isEntire() && !opt2.isEntire(),
"impossible intersection");
327 if (opt1.isEmpty()) {
330 if (opt2.isEmpty()) {
333 return opt1.join(opt2);
340 return rhs.intersect(
lhs);
353 const Field &f = field.get();
362 if (other.a == f.
zero()) {
365 if (other.b == f.
maxVal()) {
411 ensure(
lhs.field.get() ==
rhs.field.get(),
"cannot add intervals in different fields");
412 return (
lhs.firstUnreduced() +
rhs.firstUnreduced()).reduce(
lhs.field.get());
418 ensure(
lhs.field.get() ==
rhs.field.get(),
"cannot multiply intervals in different fields");
419 const auto &field =
lhs.field.get();
421 if (
lhs == zeroInterval ||
rhs == zeroInterval) {
424 if (
lhs.isEmpty() ||
rhs.isEmpty()) {
427 if (
lhs.isEntire() ||
rhs.isEntire()) {
432 return (
lhs.secondUnreduced() *
rhs.secondUnreduced()).reduce(field);
434 return (
lhs.firstUnreduced() *
rhs.firstUnreduced()).reduce(field);
438 ensure(
lhs.getField() ==
rhs.getField(),
"cannot divide intervals in different fields");
439 const auto &field =
rhs.getField();
440 if (
rhs.width() > field.one()) {
443 if (
rhs.a.isZero()) {
451 lhs.getField() ==
rhs.getField(),
"interval operations across differing fields is unsupported"
453 const auto &field =
rhs.getField();
460 os <<
'(' << a <<
')';
462 os <<
":[ " << a <<
", " << b <<
" ]";
469 if (expr ==
nullptr && rhs.expr ==
nullptr) {
472 if (expr ==
nullptr || rhs.expr ==
nullptr) {
475 return i == rhs.i && *expr == *rhs.expr;
481 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
488 res.i = lhs.i + rhs.i;
489 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
496 res.i = lhs.i - rhs.i;
497 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
504 res.i = lhs.i * rhs.i;
505 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
513 auto divRes = lhs.i / rhs.i;
514 if (failed(divRes)) {
516 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
517 " Range of division result will be treated as unbounded."
523 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
530 res.i = lhs.i % rhs.i;
531 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
540 case FeltCmpPredicate::EQ:
541 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
544 case FeltCmpPredicate::NE:
545 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
547 case FeltCmpPredicate::LT:
548 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
550 case FeltCmpPredicate::LE:
551 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
553 case FeltCmpPredicate::GT:
554 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
556 case FeltCmpPredicate::GE:
557 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
568 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
570 .Case<OrFeltOp>([&](
OrFeltOp _) {
return solver->mkBVOr(lhs.expr, rhs.expr); })
571 .Case<XorFeltOp>([&](
XorFeltOp _) {
return solver->mkBVXor(lhs.expr, rhs.expr); })
572 .Case<ShlFeltOp>([&](
ShlFeltOp _) {
return solver->mkBVShl(lhs.expr, rhs.expr); })
574 return solver->mkBVLshr(lhs.expr, rhs.expr);
575 }).Default([&](Operation *unsupported) {
576 llvm::report_fatal_error(
577 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
588 res.expr = solver->mkBVNeg(val.expr);
603 res.expr = solver->mkBVNot(val.expr);
612 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
618 llvm::SMTExprRef invSym = field.
createSymbol(solver, symName.c_str());
619 llvm::SMTExprRef one = solver->mkBitvector(field.
one(), field.
bitWidth());
620 llvm::SMTExprRef prime = solver->mkBitvector(field.
prime(), field.
bitWidth());
621 llvm::SMTExprRef mult = solver->mkBVMul(val.
getExpr(), invSym);
622 llvm::SMTExprRef
mod = solver->mkBVURem(mult, prime);
623 llvm::SMTExprRef constraint = solver->mkEqual(
mod, one);
624 solver->addConstraint(constraint);
626 }).Default([&](Operation *unsupported) {
627 llvm::report_fatal_error(
628 "no fallback provided for " + mlir::Twine(op->getName().getStringRef())
640 os <<
"<null expression>";
643 os <<
" ( interval: " << i <<
" )";
651 llvm::report_fatal_error(
"invalid join lattice type");
653 ChangeResult res = ChangeResult::NoChange;
654 for (
auto &[k, v] : rhs->valMap) {
655 auto it = valMap.find(k);
656 if (it == valMap.end() || it->second != v) {
658 res |= ChangeResult::Change;
661 for (
auto &v : rhs->constraints) {
662 if (!constraints.contains(v)) {
663 constraints.insert(v);
664 res |= ChangeResult::Change;
667 for (
auto &[e, i] : rhs->intervals) {
668 auto it = intervals.find(e);
669 if (it == intervals.end() || it->second != i) {
671 res |= ChangeResult::Change;
678 os <<
"IntervalAnalysisLattice { ";
679 for (
auto &[ref, val] : valMap) {
680 os <<
"\n (valMap) " << ref <<
" := " << val;
682 for (
auto &[expr, interval] : intervals) {
683 os <<
"\n (intervals) ";
685 os <<
" in " << interval;
687 if (!valMap.empty()) {
694 auto it = valMap.find(v);
695 if (it == valMap.end()) {
703 if (valMap[v] == val) {
704 return ChangeResult::NoChange;
708 return ChangeResult::Change;
712 if (!constraints.contains(e)) {
713 constraints.insert(e);
714 return ChangeResult::Change;
716 return ChangeResult::NoChange;
720 auto it = intervals.find(expr);
721 if (it != intervals.end()) {
739 if (action == dataflow::CallControlFlowAction::EnterCallee) {
743 setToEntryState(after);
748 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
752 if (
auto *prev = call->getPrevNode()) {
757 ensure(beforeCall,
"could not get prior lattice");
760 propagateIfChanged(after, after->
join(*beforeCall));
766 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
777 Operation *op,
const Lattice &before, Lattice *after
779 ChangeResult changed = after->
join(before);
781 llvm::SmallVector<LatticeValue> operandVals;
784 ensure(constrainRefLattice,
"failed to get lattice");
786 for (
auto &operand : op->getOpOperands()) {
787 auto val = operand.get();
789 auto priorState = before.
getValue(val);
790 if (succeeded(priorState)) {
791 operandVals.push_back(*priorState);
798 Type valTy = val.getType();
799 if (mlir::isa<ArrayType, StructType>(valTy) && !
isSignalType(valTy)) {
800 operandVals.push_back(LatticeValue());
805 ensure(refSet.
isScalar(),
"should have ruled out array values already");
812 debug::Appender(warning
814 << val <<
" is empty; defining operation is unsupported by constrain ref analysis";
815 op->emitWarning(warning);
816 propagateIfChanged(after, changed);
820 debug::Appender(warning) <<
"operand " << val <<
" is not a single value " << refSet
821 <<
", overapproximating";
822 op->emitWarning(warning);
826 changed |= after->
setValue(val, anyVal);
827 operandVals.emplace_back(anyVal);
831 changed |= after->
setValue(val, exprVal);
832 operandVals.emplace_back(exprVal);
837 if (!isConsideredOp(op)) {
838 op->emitWarning(
"unconsidered operation type, analysis may be incomplete");
842 auto constVal = getConst(op);
843 auto expr = createConstBitvectorExpr(constVal);
845 changed |= after->
setValue(op->getResult(0), latticeVal);
846 }
else if (isArithmeticOp(op)) {
847 ensure(operandVals.size() <= 2,
"arithmetic op with the wrong number of operands");
849 if (operandVals.size() == 2) {
850 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
852 result = performUnaryArithmetic(op, operandVals[0]);
855 changed |= after->
setValue(op->getResult(0), result);
856 }
else if (
EmitEqualityOp emitEq = mlir::dyn_cast<EmitEqualityOp>(op)) {
857 ensure(operandVals.size() == 2,
"constraint op with the wrong number of operands");
858 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
865 changed |= applyInterval(emitEq, after, lhsVal, constraint.
getInterval());
866 changed |= applyInterval(emitEq, after, rhsVal, constraint.
getInterval());
868 }
else if (
AssertOp assertOp = mlir::dyn_cast<AssertOp>(op)) {
869 ensure(operandVals.size() == 1,
"assert op with the wrong number of operands");
872 changed |= applyInterval(
873 assertOp, after, assertOp.getCondition(),
877 auto assertExpr = operandVals[0].getScalarValue();
879 }
else if (
auto readf = mlir::dyn_cast<FieldReadOp>(op);
880 readf &&
isSignalType(readf.getComponent().getType())) {
883 changed |= after->
setValue(readf.getVal(), operandVals[0].getScalarValue());
884 }
else if (!isReadOp(op)
889 && !isDefinitionOp(op)
891 !mlir::isa<CreateStructOp>(op)
893 op->emitWarning(
"unhandled operation, analysis may be incomplete");
896 propagateIfChanged(after, changed);
900 auto it = refSymbols.find(r);
901 if (it != refSymbols.end()) {
904 auto sym = createFeltSymbol(r);
909llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const ConstrainRef &r)
const {
913llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v)
const {
914 return createFeltSymbol(buildStringViaPrint(v).c_str());
917llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const char *name)
const {
918 return field.get().createSymbol(smtSolver, name);
921llvm::APSInt IntervalDataFlowAnalysis::getConst(Operation *op)
const {
922 ensure(isConstOp(op),
"op is not a const op");
924 llvm::APInt fieldConst =
925 TypeSwitch<Operation *, llvm::APInt>(op)
926 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
927 llvm::APSInt constOpVal(feltConst.getValueAttr().getValue());
928 return field.get().reduce(constOpVal);
930 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
931 return llvm::APInt(field.get().bitWidth(), indexConst.value());
933 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
934 return llvm::APInt(field.get().bitWidth(), intConst.value());
935 }).Default([&](Operation *illegalOp) {
937 debug::Appender(err) <<
"unhandled getConst case: " << *illegalOp;
938 llvm::report_fatal_error(Twine(err));
939 return llvm::APInt();
941 return llvm::APSInt(fieldConst);
944ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
945 Operation *op,
const LatticeValue &a,
const LatticeValue &b
947 ensure(isArithmeticOp(op),
"is not arithmetic op");
949 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
950 ensure(lhs.getExpr(),
"cannot perform arithmetic over null lhs smt expr");
951 ensure(rhs.getExpr(),
"cannot perform arithmetic over null rhs smt expr");
953 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
954 .Case<AddFeltOp>([&](AddFeltOp _) {
return add(smtSolver, lhs, rhs); })
955 .Case<SubFeltOp>([&](SubFeltOp _) {
return sub(smtSolver, lhs, rhs); })
956 .Case<MulFeltOp>([&](MulFeltOp _) {
return mul(smtSolver, lhs, rhs); })
957 .Case<DivFeltOp>([&](DivFeltOp divOp) {
return div(smtSolver, divOp, lhs, rhs); })
958 .Case<ModFeltOp>([&](ModFeltOp _) {
return mod(smtSolver, lhs, rhs); })
959 .Case<CmpOp>([&](CmpOp cmpOp) {
960 return cmp(smtSolver, cmpOp, lhs, rhs);
961 }).Default([&](Operation *unsupported) {
962 unsupported->emitWarning(
963 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
968 ensure(res.getExpr(),
"arithmetic produced null smt expr");
973IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op,
const LatticeValue &a) {
974 ensure(isArithmeticOp(op),
"is not arithmetic op");
976 auto val = a.getScalarValue();
977 ensure(val.getExpr(),
"cannot perform arithmetic over null smt expr");
979 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
980 .Case<NegFeltOp>([&](NegFeltOp _) {
return neg(smtSolver, val); })
981 .Case<NotFeltOp>([&](NotFeltOp _) {
982 return notOp(smtSolver, val);
983 }).Default([&](Operation *unsupported) {
984 unsupported->emitWarning(
985 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
990 ensure(res.getExpr(),
"arithmetic produced null smt expr");
994ChangeResult IntervalDataFlowAnalysis::applyInterval(
995 Operation *originalOp, Lattice *after, Value val, Interval newInterval
997 auto latValRes = after->getValue(val);
998 if (failed(latValRes)) {
1000 return ChangeResult::NoChange;
1002 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
1003 ChangeResult res = after->setValue(val, newLatticeVal);
1006 Lattice *valLattice =
nullptr;
1007 if (
auto valOp = val.getDefiningOp()) {
1011 if (
auto prev = valOp->getPrevNode()) {
1012 valLattice = getOrCreate<Lattice>(prev);
1014 valLattice = getOrCreate<Lattice>(valOp->getBlock());
1016 }
else if (
auto blockArg = mlir::dyn_cast<BlockArgument>(val)) {
1017 valLattice = getOrCreate<Lattice>(blockArg.getOwner());
1019 valLattice = getOrCreate<Lattice>(val);
1021 ensure(valLattice,
"val should have a lattice");
1022 if (valLattice != after) {
1023 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
1027 Operation *definingOp = val.getDefiningOp();
1032 const Field &f = field.get();
1040 auto cmpCase = [&](CmpOp cmpOp) {
1043 Interval maxInterval = Interval::Boolean(f);
1045 newInterval.intersect(maxInterval).isNotEmpty(),
1046 "new interval for CmpOp outside of allowed boolean range or is empty"
1048 if (!newInterval.isDegenerate()) {
1050 return ChangeResult::NoChange;
1053 bool cmpTrue = newInterval.rhs() == f.one();
1055 Value lhs = cmpOp->getOperand(0), rhs = cmpOp->getOperand(1);
1056 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
1057 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
1058 return ChangeResult::NoChange;
1060 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
1061 rhsExpr = rhsLatValRes->getScalarValue();
1063 Interval newLhsInterval, newRhsInterval;
1064 const Interval &lhsInterval = lhsExpr.getInterval();
1065 const Interval &rhsInterval = rhsExpr.getInterval();
1069 auto eqCase = [&]() {
1070 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1071 (pred == FeltCmpPredicate::NE && !cmpTrue);
1073 auto neCase = [&]() {
1074 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1075 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1077 auto ltCase = [&]() {
1078 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1079 (pred == FeltCmpPredicate::GE && !cmpTrue);
1081 auto leCase = [&]() {
1082 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1083 (pred == FeltCmpPredicate::GT && !cmpTrue);
1085 auto gtCase = [&]() {
1086 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1087 (pred == FeltCmpPredicate::LE && !cmpTrue);
1089 auto geCase = [&]() {
1090 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1091 (pred == FeltCmpPredicate::LT && !cmpTrue);
1096 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1097 }
else if (neCase()) {
1098 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1101 newLhsInterval = newRhsInterval = Interval::Empty(f);
1104 newLhsInterval = lhsInterval;
1105 newRhsInterval = rhsInterval;
1107 }
else if (ltCase()) {
1108 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1109 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1110 }
else if (leCase()) {
1111 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1112 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1113 }
else if (gtCase()) {
1114 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1115 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1116 }
else if (geCase()) {
1117 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1118 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1120 cmpOp->emitWarning(
"unhandled cmp predicate");
1121 return ChangeResult::NoChange;
1125 return applyInterval(originalOp, after, lhs, newLhsInterval) |
1126 applyInterval(originalOp, after, rhs, newRhsInterval);
1131 auto mulCase = [&](MulFeltOp mulOp) {
1132 auto zeroInt = Interval::Degenerate(f, f.zero());
1133 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1135 return ChangeResult::NoChange;
1138 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
1139 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
1140 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
1141 return ChangeResult::NoChange;
1143 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
1144 rhsExpr = rhsLatValRes->getScalarValue();
1145 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1146 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1147 return applyInterval(originalOp, after, lhs, newLhsInterval) |
1148 applyInterval(originalOp, after, rhs, newRhsInterval);
1154 auto readfCase = [&](FieldReadOp readfOp) {
1155 Value comp = readfOp.getComponent();
1157 return applyInterval(originalOp, after, comp, newInterval);
1159 return ChangeResult::NoChange;
1166 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
1167 .Case<CmpOp>([&](CmpOp op) {
return cmpCase(op); })
1168 .Case<MulFeltOp>([&](MulFeltOp op) {
return mulCase(op); })
1169 .Case<FieldReadOp>([&](FieldReadOp op){
return readfCase(op); })
1170 .Default([&](Operation *_) {
return ChangeResult::NoChange; });
1192 for (
const auto &ref : ConstrainRef::getAllConstrainRefs(structDef)) {
1195 if (!ref.isScalar() && !ref.isSignal()) {
1199 if (
auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1203 auto constrainInterval = constrainLattice->
findInterval(symbol);
1204 if (succeeded(constrainInterval)) {
1205 constrainFieldRanges[ref] = *constrainInterval;
1215 os <<
"StructIntervals { ";
1216 if (constrainFieldRanges.empty()) {
1221 for (
auto &[ref, interval] : constrainFieldRanges) {
1222 os <<
"\n " << ref <<
" in " << interval;
1225 if (withConstraints) {
1226 os <<
"\n\n Solver Constraints { ";
1227 if (constrainSolverConstraints.empty()) {
1230 for (
const auto &e : constrainSolverConstraints) {
1232 e.getExpr()->print(os);
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
A value at a given point of the ConstrainRefLattice.
const ConstrainRef & getSingleValue() const
A lattice for use in dense analysis.
Defines a reference to a llzk object within a constrain function call.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
friend ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
const Interval & getInterval() const
friend ExpressionValue cmp(llvm::SMTSolverRef solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue div(llvm::SMTSolverRef solver, felt::DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
const Field & getField() const
friend ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
Information about the prime finite field used for the interval analysis.
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::APSInt zero() const
Returns 0 at the bitwidth of the field.
llvm::APSInt half() const
Returns p / 2.
unsigned bitWidth() const
llvm::APSInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
static const Field & getField(const char *fieldName)
Get a Field from a given field name string.
llvm::APSInt prime() const
For the prime field p, returns p.
llvm::APSInt reduce(llvm::APSInt i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
Maps mlir::Values to LatticeValues.
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
Intervals over a finite field.
Interval intersect(const Interval &rhs) const
Intersect.
static std::string_view TypeName(Type t)
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
bool isDegenerate() const
void print(mlir::raw_ostream &os) const
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
static Interval TypeC(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeF(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeB(const Field &f, llvm::APSInt a, llvm::APSInt b)
friend Interval operator*(const Interval &lhs, const Interval &rhs)
static Interval Empty(const Field &f)
static bool areOneOf(const Interval &a, const Interval &b)
static Interval Degenerate(const Field &f, llvm::APSInt val)
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
friend mlir::FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Returns failure if a division-by-zero is encountered.
static Interval TypeA(const Field &f, llvm::APSInt a, llvm::APSInt b)
friend Interval operator+(const Interval &lhs, const Interval &rhs)
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
friend Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
void print(mlir::raw_ostream &os, bool withConstraints=false) const
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
UnreducedInterval operator-() const
friend UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
bool isEmpty() const
Returns true iff width() is zero.
UnreducedInterval(llvm::APSInt x, llvm::APSInt y)
llvm::APSInt width() const
Compute the width of this interval within a given field f.
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
friend std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
bool overlaps(const UnreducedInterval &rhs) const
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
friend UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
::llzk::boolean::FeltCmpPredicate getPredicate()
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
bool isSingleValue() const
const ScalarTy & getScalarValue() const
IntervalAnalysisLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingSub(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely subtract lhs and rhs, expanding the width of the result as necessary.
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingAdd(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely add lhs and rhs, expanding the width of the result as necessary.
llvm::APSInt safeMax(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, llvm::Twine errMsg)
llvm::APSInt expandingMul(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely multiple lhs and rhs, expanding the width of the result as necessary.
bool safeGt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
bool safeLt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
bool safeLe(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
bool safeGe(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
llvm::APSInt safeMin(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
llvm::SMTExprRef getSymbol(const ConstrainRef &r)