15#include <mlir/Dialect/SCF/IR/SCF.h>
17#include <llvm/ADT/TypeSwitch.h>
34 if (expr ==
nullptr && rhs.expr ==
nullptr) {
37 if (expr ==
nullptr || rhs.expr ==
nullptr) {
40 return i == rhs.i && *expr == *rhs.expr;
45 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
46 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
47 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.
getExpr(), one, zero);
54 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
61 res.i = lhs.i + rhs.i;
62 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
69 res.i = lhs.i - rhs.i;
70 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
77 res.i = lhs.i * rhs.i;
78 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
86 auto divRes = lhs.i / rhs.i;
89 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
90 " Range of division result will be treated as unbounded."
97 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
104 res.i = lhs.i % rhs.i;
105 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
112 res.i = lhs.i & rhs.i;
113 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
120 res.i = lhs.i << rhs.i;
121 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
128 res.i = lhs.i >> rhs.i;
129 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
140 case FeltCmpPredicate::EQ:
141 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
148 case FeltCmpPredicate::NE:
149 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
156 case FeltCmpPredicate::LT:
157 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
159 case FeltCmpPredicate::LE:
160 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
162 case FeltCmpPredicate::GT:
163 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
165 case FeltCmpPredicate::GE:
166 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
176 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
183 res.i =
boolOr(lhs.i, rhs.i);
184 res.expr = solver->mkOr(lhs.expr, rhs.expr);
193 res.expr = solver->mkAnd(
194 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
204 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
205 .Case<
OrFeltOp>([&](
auto) {
return solver->mkBVOr(lhs.expr, rhs.expr); })
206 .Case<XorFeltOp>([&](
auto) {
207 return solver->mkBVXor(lhs.expr, rhs.expr);
208 }).Default([&](
auto *unsupported) {
209 llvm::report_fatal_error(
210 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
221 res.expr = solver->mkBVNeg(val.expr);
228 res.expr = solver->mkBVNot(val.expr);
235 res.expr = solver->mkNot(val.expr);
244 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
250 llvm::SMTExprRef invSym = field.
createSymbol(solver, symName.c_str());
251 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.
bitWidth());
253 llvm::SMTExprRef mult = solver->mkBVMul(val.
getExpr(), invSym);
254 llvm::SMTExprRef
mod = solver->mkBVURem(mult, prime);
255 llvm::SMTExprRef constraint = solver->mkEqual(
mod, one);
256 solver->addConstraint(constraint);
258 }).Default([](Operation *unsupported) {
259 llvm::report_fatal_error(
260 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
272 os <<
"<null expression>";
275 os <<
" ( interval: " << i <<
" )";
283 llvm::report_fatal_error(
"invalid join lattice type");
285 ChangeResult res = ChangeResult::NoChange;
286 for (
auto &[k, v] : rhs->valMap) {
287 auto it = valMap.find(k);
288 if (it == valMap.end() || it->second != v) {
290 res |= ChangeResult::Change;
293 for (
auto &v : rhs->constraints) {
294 if (!constraints.contains(v)) {
295 constraints.insert(v);
296 res |= ChangeResult::Change;
299 for (
auto &[e, i] : rhs->intervals) {
300 auto it = intervals.find(e);
301 if (it == intervals.end() || it->second != i) {
303 res |= ChangeResult::Change;
310 os <<
"IntervalAnalysisLattice { ";
311 for (
auto &[ref, val] : valMap) {
312 os <<
"\n (valMap) " << ref <<
" := " << val;
314 for (
auto &[expr, interval] : intervals) {
315 os <<
"\n (intervals) ";
321 os <<
" in " << interval;
323 if (!valMap.empty()) {
330 auto it = valMap.find(v);
331 if (it == valMap.end()) {
337FailureOr<IntervalAnalysisLattice::LatticeValue>
339 auto it = fieldMap.find(v);
340 if (it == fieldMap.end()) {
343 auto fit = it->second.find(f);
344 if (fit == it->second.end()) {
351 if (valMap[v] == val) {
352 return ChangeResult::NoChange;
355 ExpressionValue e = val.foldToScalar();
356 intervals[e.getExpr()] = e.getInterval();
357 return ChangeResult::Change;
362 if (valMap[v] == val) {
363 return ChangeResult::NoChange;
366 intervals[e.getExpr()] = e.getInterval();
367 return ChangeResult::Change;
372 if (fieldMap[v][f] == val) {
373 return ChangeResult::NoChange;
375 fieldMap[v][f] = val;
376 intervals[e.getExpr()] = e.getInterval();
377 return ChangeResult::Change;
381 if (!constraints.contains(e)) {
382 constraints.insert(e);
383 return ChangeResult::Change;
385 return ChangeResult::NoChange;
389 auto it = intervals.find(expr);
390 if (it != intervals.end()) {
397 auto it = intervals.find(expr);
398 if (it != intervals.end() && it->second == i) {
399 return ChangeResult::NoChange;
402 return ChangeResult::Change;
417 if (action == dataflow::CallControlFlowAction::EnterCallee) {
421 setToEntryState(after);
426 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
430 ensure(beforeCall,
"could not get prior lattice");
433 propagateIfChanged(after, after->
join(*beforeCall));
439 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
450IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) {
451 ProgramPoint *pp = _dataflowSolver.getProgramPointAfter(baseOp);
452 auto defaultSourceRefLattice = _dataflowSolver.lookupState<
SourceRefLattice>(pp);
453 ensure(defaultSourceRefLattice,
"failed to get lattice");
454 if (Operation *defOp = val.getDefiningOp()) {
455 ProgramPoint *defPoint = _dataflowSolver.getProgramPointAfter(defOp);
456 auto sourceRefLattice = _dataflowSolver.lookupState<
SourceRefLattice>(defPoint);
457 ensure(sourceRefLattice,
"failed to get SourceRefLattice for value");
458 return sourceRefLattice;
460 return defaultSourceRefLattice;
471 ChangeResult changed = ChangeResult::NoChange;
475 for (BlockArgument blockArg : fn.getArguments()) {
476 auto blockArgLookupRes = before.
getValue(blockArg);
477 if (succeeded(blockArgLookupRes)) {
478 changed |= after->
setValue(blockArg, *blockArgLookupRes);
482 auto getAfter = [&](Value val) {
483 if (Operation *defOp = val.getDefiningOp()) {
484 return getLattice(getProgramPointAfter(defOp));
485 }
else if (
auto blockArg = dyn_cast<BlockArgument>(val)) {
486 Operation *blockEntry = &blockArg.getOwner()->front();
487 return getLattice(getProgramPointBefore(blockEntry));
492 llvm::SmallVector<LatticeValue> operandVals;
493 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
494 for (OpOperand &operand : op->getOpOperands()) {
495 Value val = operand.get();
500 operandRefs.push_back(std::nullopt);
503 Lattice *valLattice = getAfter(val);
504 auto priorState = valLattice->
getValue(val);
505 if (succeeded(priorState) && priorState->getScalarValue().getExpr() !=
nullptr) {
506 operandVals.push_back(*priorState);
507 changed |= after->
setValue(val, *priorState);
514 Type valTy = val.getType();
515 if (llvm::isa<ArrayType, StructType>(valTy) && !
isSignalType(valTy)) {
517 operandVals.push_back(empty);
518 changed |= after->
setValue(val, empty);
522 ensure(refSet.
isScalar(),
"should have ruled out array values already");
530 "state of ", val,
" is empty; defining operation is unsupported by SourceRef analysis"
533 propagateIfChanged(after, changed);
539 debug::Appender(warning) <<
"operand " << val <<
" is not a single value " << refSet
540 <<
", overapproximating";
541 op->emitWarning(warning).report();
545 changed |= after->
setValue(val, anyVal);
546 operandVals.emplace_back(anyVal);
550 if (succeeded(priorState)) {
551 exprVal = exprVal.
withInterval(priorState->getScalarValue().getInterval());
553 changed |= after->
setValue(val, exprVal);
554 operandVals.emplace_back(exprVal);
561 ensure(succeeded(res),
"expected precondition is that value is set");
562 (void)valLattice->
setValue(val, *res);
567 llvm::DynamicAPInt constVal = getConst(op);
568 llvm::SMTExprRef expr = createConstBitvectorExpr(constVal);
570 changed |= after->
setValue(op->getResult(0), latticeVal);
571 }
else if (isArithmeticOp(op)) {
572 ensure(operandVals.size() <= 2,
"arithmetic op with the wrong number of operands");
574 if (operandVals.size() == 2) {
575 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
577 result = performUnaryArithmetic(op, operandVals[0]);
580 changed |= after->
setValue(op->getResult(0), result);
581 }
else if (
EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
582 ensure(operandVals.size() == 2,
"constraint op with the wrong number of operands");
583 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
589 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
590 if (succeeded(res)) {
591 for (Value signalVal : res->first) {
592 changed |= applyInterval(emitEq, after, getAfter(signalVal), signalVal, res->second);
600 changed |= applyInterval(emitEq, after, getAfter(lhsVal), lhsVal, constrainInterval);
601 changed |= applyInterval(emitEq, after, getAfter(rhsVal), rhsVal, constrainInterval);
603 }
else if (
AssertOp assertOp = llvm::dyn_cast<AssertOp>(op)) {
604 ensure(operandVals.size() == 1,
"assert op with the wrong number of operands");
607 changed |= applyInterval(
608 assertOp, after, after, assertOp.getCondition(),
612 auto assertExpr = operandVals[0].getScalarValue();
614 }
else if (
auto readf = llvm::dyn_cast<FieldReadOp>(op)) {
615 Value
cmp = readf.getComponent();
619 changed |= after->
setValue(readf.getVal(), operandVals[0].getScalarValue());
621 auto storedVal = getAfter(
cmp)->getValue(
cmp, readf.getFieldNameAttr().getAttr());
622 if (succeeded(storedVal)) {
624 changed |= after->
setValue(readf.getVal(), storedVal->getScalarValue());
625 }
else if (operandRefs[0].has_value()) {
627 auto fieldDefRes = readf.getFieldDefOp(tables);
628 if (succeeded(fieldDefRes)) {
631 changed |= after->
setValue(readf.getVal(), exprVal);
635 }
else if (
auto writef = llvm::dyn_cast<FieldWriteOp>(op)) {
638 auto cmp = writef.getComponent();
639 changed |= after->
setValue(
cmp, writef.getFieldNameAttr().getAttr(), writeVal);
643 auto fieldDefRes = writef.getFieldDefOp(tables);
644 if (succeeded(fieldDefRes)) {
651 }
else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
658 expr =
boolToFelt(smtSolver, expr, field.get().bitWidth());
660 changed |= after->
setValue(op->getResult(0), expr);
661 }
else if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
664 Operation *parent = op->getParentOp();
665 ensure(parent,
"yield operation must have parent operation");
666 auto postYieldLattice =
getLattice(getProgramPointAfter(parent));
667 ensure(postYieldLattice,
"could not fetch post-yield lattice");
669 for (
unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
670 Value parentRes = parent->getResult(idx);
673 auto exprValRes = postYieldLattice->getValue(parentRes);
675 if (succeeded(exprValRes)) {
682 changed |= after->
setValue(parentRes, newResVal);
685 propagateIfChanged(postYieldLattice, postYieldLattice->join(*after));
694 && !isDefinitionOp(op)
696 && !llvm::isa<CreateStructOp>(op)
698 op->emitWarning(
"unhandled operation, analysis may be incomplete").report();
701 propagateIfChanged(after, changed);
706 auto it = refSymbols.find(r);
707 if (it != refSymbols.end()) {
710 llvm::SMTExprRef sym = createFeltSymbol(r);
715llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const SourceRef &r)
const {
719llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v)
const {
723llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const char *
name)
const {
724 return field.get().createSymbol(smtSolver,
name);
727llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op)
const {
728 ensure(isConstOp(op),
"op is not a const op");
730 llvm::DynamicAPInt fieldConst =
731 TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
732 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
733 llvm::APSInt constOpVal(feltConst.getValue());
734 return field.get().reduce(constOpVal);
736 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
737 return DynamicAPInt(indexConst.value());
739 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
740 return DynamicAPInt(intConst.value());
741 }).Default([](Operation *illegalOp) {
743 debug::Appender(err) <<
"unhandled getConst case: " << *illegalOp;
744 llvm::report_fatal_error(Twine(err));
745 return llvm::DynamicAPInt();
751 Operation *op,
const LatticeValue &a,
const LatticeValue &b
753 ensure(isArithmeticOp(op),
"is not arithmetic op");
755 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
756 ensure(lhs.getExpr(),
"cannot perform arithmetic over null lhs smt expr");
757 ensure(rhs.getExpr(),
"cannot perform arithmetic over null rhs smt expr");
759 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
760 .Case<AddFeltOp>([&](
auto _) {
return add(smtSolver, lhs, rhs); })
761 .Case<SubFeltOp>([&](
auto _) {
return sub(smtSolver, lhs, rhs); })
762 .Case<MulFeltOp>([&](
auto _) {
return mul(smtSolver, lhs, rhs); })
763 .Case<DivFeltOp>([&](
auto divOp) {
return div(smtSolver, divOp, lhs, rhs); })
764 .Case<ModFeltOp>([&](
auto _) {
return mod(smtSolver, lhs, rhs); })
765 .Case<AndFeltOp>([&](
auto _) {
return bitAnd(smtSolver, lhs, rhs); })
766 .Case<ShlFeltOp>([&](
auto _) {
return shiftLeft(smtSolver, lhs, rhs); })
767 .Case<ShrFeltOp>([&](
auto _) {
return shiftRight(smtSolver, lhs, rhs); })
768 .Case<CmpOp>([&](
auto cmpOp) {
return cmp(smtSolver, cmpOp, lhs, rhs); })
769 .Case<AndBoolOp>([&](
auto _) {
return boolAnd(smtSolver, lhs, rhs); })
770 .Case<OrBoolOp>([&](
auto _) {
return boolOr(smtSolver, lhs, rhs); })
771 .Case<XorBoolOp>([&](
auto _) {
772 return boolXor(smtSolver, lhs, rhs);
773 }).Default([&](
auto *unsupported) {
776 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
782 ensure(res.getExpr(),
"arithmetic produced null smt expr");
787IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op,
const LatticeValue &a) {
788 ensure(isArithmeticOp(op),
"is not arithmetic op");
790 auto val = a.getScalarValue();
791 ensure(val.getExpr(),
"cannot perform arithmetic over null smt expr");
793 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
794 .Case<NegFeltOp>([&](
auto _) {
return neg(smtSolver, val); })
795 .Case<NotFeltOp>([&](
auto _) {
return notOp(smtSolver, val); })
796 .Case<NotBoolOp>([&](
auto _) {
return boolNot(smtSolver, val); })
798 .Case<InvFeltOp>([&](
auto inv) {
800 }).Default([&](
auto *unsupported) {
803 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
809 ensure(res.getExpr(),
"arithmetic produced null smt expr");
813ChangeResult IntervalDataFlowAnalysis::applyInterval(
814 Operation *originalOp, Lattice *originalLattice, Lattice *after, Value val,
Interval newInterval
816 auto latValRes = after->getValue(val);
817 if (failed(latValRes)) {
819 return ChangeResult::NoChange;
821 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
822 propagateIfChanged(after, after->setValue(val, newLatticeVal));
823 ChangeResult res = originalLattice->setValue(val, newLatticeVal);
826 Lattice *valLattice =
nullptr;
827 if (Operation *valOp = val.getDefiningOp()) {
828 valLattice =
getLattice(getProgramPointAfter(valOp));
829 }
else if (
auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
830 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
831 Operation *blockEntry = &blockArg.getOwner()->front();
834 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
835 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
836 auto structOp = fnOp->getParentOfType<StructDefOp>();
837 FuncDefOp computeFn = structOp.getComputeFuncOp();
838 Operation *computeEntry = &computeFn.getRegion().front().front();
839 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
840 Lattice *computeEntryLattice =
getLattice(getProgramPointBefore(computeEntry));
842 SourceRef ref(computeArg);
844 ChangeResult computeRes = computeEntryLattice->setValue(computeArg, newArgVal);
845 propagateIfChanged(computeEntryLattice, computeRes);
848 valLattice =
getLattice(getProgramPointBefore(blockEntry));
853 ensure(valLattice,
"val should have a lattice");
854 auto setNewVal = [&valLattice, &val, &newLatticeVal, &res,
this]() {
855 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
860 Operation *definingOp = val.getDefiningOp();
864 Lattice *definingOpLattice =
getLattice(getProgramPointAfter(definingOp));
865 auto getOperandLattice = [&](Value operand) {
866 if (Operation *defOp = operand.getDefiningOp()) {
867 return getLattice(getProgramPointAfter(defOp));
868 }
else if (
auto blockArg = dyn_cast<BlockArgument>(operand)) {
869 Operation *blockEntry = &blockArg.getOwner()->front();
870 return getLattice(getProgramPointBefore(blockEntry));
872 return definingOpLattice;
874 auto getOperandLatticeVal = [&](Value operand) {
875 return getOperandLattice(operand)->getValue(operand);
878 const Field &f = field.get();
886 auto cmpCase = [&](CmpOp cmpOp) {
890 newInterval.isBoolean(),
891 "new interval for CmpOp outside of allowed boolean range or is empty"
893 if (!newInterval.isDegenerate()) {
895 return ChangeResult::NoChange;
898 bool cmpTrue = newInterval.rhs() == f.one();
900 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
901 auto lhsLatValRes = getOperandLatticeVal(lhs), rhsLatValRes = getOperandLatticeVal(rhs);
902 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
903 return ChangeResult::NoChange;
905 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
906 rhsExpr = rhsLatValRes->getScalarValue();
908 Interval newLhsInterval, newRhsInterval;
909 const Interval &lhsInterval = lhsExpr.getInterval();
910 const Interval &rhsInterval = rhsExpr.getInterval();
914 auto eqCase = [&]() {
915 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
916 (pred == FeltCmpPredicate::NE && !cmpTrue);
918 auto neCase = [&]() {
919 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
920 (pred == FeltCmpPredicate::EQ && !cmpTrue);
922 auto ltCase = [&]() {
923 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
924 (pred == FeltCmpPredicate::GE && !cmpTrue);
926 auto leCase = [&]() {
927 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
928 (pred == FeltCmpPredicate::GT && !cmpTrue);
930 auto gtCase = [&]() {
931 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
932 (pred == FeltCmpPredicate::LE && !cmpTrue);
934 auto geCase = [&]() {
935 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
936 (pred == FeltCmpPredicate::LT && !cmpTrue);
941 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
942 }
else if (neCase()) {
943 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
947 }
else if (lhsInterval.isDegenerate()) {
949 newLhsInterval = lhsInterval;
950 newRhsInterval = rhsInterval.difference(lhsInterval);
951 }
else if (rhsInterval.isDegenerate()) {
953 newLhsInterval = lhsInterval.difference(rhsInterval);
954 newRhsInterval = rhsInterval;
957 newLhsInterval = lhsInterval;
958 newRhsInterval = rhsInterval;
960 }
else if (ltCase()) {
961 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
962 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
963 }
else if (leCase()) {
964 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
965 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
966 }
else if (gtCase()) {
967 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
968 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
969 }
else if (geCase()) {
970 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
971 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
973 cmpOp->emitWarning(
"unhandled cmp predicate").report();
974 return ChangeResult::NoChange;
978 return applyInterval(originalOp, originalLattice, getOperandLattice(lhs), lhs, newLhsInterval) |
979 applyInterval(originalOp, originalLattice, getOperandLattice(rhs), rhs, newRhsInterval);
984 auto mulCase = [&](MulFeltOp mulOp) {
986 if (newInterval.intersect(zeroInt).isNotEmpty()) {
988 return ChangeResult::NoChange;
991 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
992 auto lhsLatValRes = getOperandLatticeVal(lhs), rhsLatValRes = getOperandLatticeVal(rhs);
993 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
994 return ChangeResult::NoChange;
996 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
997 rhsExpr = rhsLatValRes->getScalarValue();
998 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
999 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1000 return applyInterval(originalOp, originalLattice, getOperandLattice(lhs), lhs, newLhsInterval) |
1001 applyInterval(originalOp, originalLattice, getOperandLattice(rhs), rhs, newRhsInterval);
1007 auto readfCase = [&](FieldReadOp readfOp) {
1008 Value comp = readfOp.getComponent();
1010 return applyInterval(originalOp, originalLattice, getOperandLattice(comp), comp, newInterval);
1012 return ChangeResult::NoChange;
1016 auto castCase = [&](Operation *op) {
1017 Value operand = op->getOperand(0);
1018 return applyInterval(
1019 originalOp, originalLattice, getOperandLattice(operand), operand, newInterval
1027 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
1028 .Case<CmpOp>([&](
auto op) {
return cmpCase(op); })
1029 .Case<MulFeltOp>([&](
auto op) {
return mulCase(op); })
1030 .Case<FieldReadOp>([&](
auto op){
return readfCase(op); })
1031 .Case<IntToFeltOp, FeltToIndexOp>([&](
auto op) {
return castCase(op); })
1032 .Default([&](Operation *) {
return ChangeResult::NoChange; });
1039FailureOr<std::pair<DenseSet<Value>,
Interval>>
1040IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
1041 auto isZeroConst = [
this](Value v) {
1042 Operation *op = v.getDefiningOp();
1046 if (!isConstOp(op)) {
1049 return getConst(op) == field.get().zero();
1051 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1052 Value exprTree =
nullptr;
1053 if (lhsIsZero && !rhsIsZero) {
1055 }
else if (!lhsIsZero && rhsIsZero) {
1062 std::optional<SourceRef> signalRef = std::nullopt;
1063 DenseSet<Value> signalVals;
1064 SmallVector<DynamicAPInt> consts;
1065 SmallVector<Value> frontier {exprTree};
1066 while (!frontier.empty()) {
1067 Value v = frontier.back();
1068 frontier.pop_back();
1069 Operation *op = v.getDefiningOp();
1073 auto handleRefValue = [
this, &baseOp, &signalRef, &signalVal, &signalVals]() {
1074 SourceRefLatticeValue refSet =
1075 getSourceRefLattice(baseOp, signalVal)->getOrDefault(signalVal);
1076 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1079 SourceRef r = refSet.getSingleValue();
1080 if (signalRef.has_value() && signalRef.value() != r) {
1082 }
else if (!signalRef.has_value()) {
1085 signalVals.insert(signalVal);
1090 if (op && matchPattern(op, subPattern)) {
1091 if (failed(handleRefValue())) {
1094 auto constInt = APSInt(c.getValue());
1095 consts.push_back(field.get().reduce(constInt));
1097 }
else if (
m_RefValue(&signalVal).match(v)) {
1098 if (failed(handleRefValue())) {
1101 consts.push_back(field.get().zero());
1107 if (op && matchPattern(op, mulPattern)) {
1108 frontier.push_back(a);
1109 frontier.push_back(b);
1118 std::sort(consts.begin(), consts.end());
1119 Interval iv =
Interval::TypeA(field.get(), consts.front(), consts.back());
1120 return std::make_pair(std::move(signalVals), iv);
1125static void getReversedOps(Region *r, llvm::SmallVector<Operation *> &opList) {
1126 for (Block &b : llvm::reverse(*r)) {
1127 for (Operation &op : llvm::reverse(b)) {
1128 for (Region &nested : llvm::reverse(op.getRegions())) {
1129 getReversedOps(&nested, opList);
1131 opList.push_back(&op);
1140 auto computeIntervalsImpl = [&solver, &ctx,
this](
1141 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &fieldRanges,
1142 llvm::SetVector<ExpressionValue> &solverConstraints
1153 if (!ref.isScalar() && !ref.isSignal()) {
1157 if (
auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1160 searchSet.insert(ref);
1164 llvm::SmallVector<Operation *> opList;
1165 getReversedOps(&fn.getBody(), opList);
1168 opList.push_back(fn);
1170 for (Operation *op : opList) {
1171 ProgramPoint *pp = solver.getProgramPointAfter(op);
1174 solverConstraints.insert(c.
begin(), c.
end());
1177 for (
const auto &ref : searchSet) {
1180 if (succeeded(intervalRes)) {
1181 fieldRanges[ref] = *intervalRes;
1183 newSearchSet.insert(ref);
1186 searchSet = newSearchSet;
1190 for (
const auto &ref : searchSet) {
1195 llvm::sort(fieldRanges, [](
auto a,
auto b) {
return std::get<0>(a) < std::get<0>(b); });
1198 computeIntervalsImpl(structDef.getComputeFuncOp(), computeFieldRanges, computeSolverConstraints);
1199 computeIntervalsImpl(
1200 structDef.getConstrainFuncOp(), constrainFieldRanges, constrainSolverConstraints
1207 auto writeIntervals =
1208 [&os, &withConstraints](
1209 const char *fnName,
const llvm::MapVector<SourceRef, Interval> &fieldRanges,
1210 const llvm::SetVector<ExpressionValue> &solverConstraints,
bool printName
1215 os.indent(indent) << fnName <<
" {";
1219 if (fieldRanges.empty()) {
1224 for (
auto &[ref, interval] : fieldRanges) {
1226 os.indent(indent) << ref <<
" in " << interval;
1229 if (withConstraints) {
1231 os.indent(indent) <<
"Solver Constraints { ";
1232 if (solverConstraints.empty()) {
1235 for (
const auto &e : solverConstraints) {
1237 os.indent(indent + 4);
1238 e.getExpr()->print(os);
1241 os.indent(indent) <<
'}';
1247 os.indent(indent - 4) <<
'}';
1251 os <<
"StructIntervals { ";
1252 if (constrainFieldRanges.empty() && (!printCompute || computeFieldRanges.empty())) {
1258 writeIntervals(
FUNC_NAME_COMPUTE, computeFieldRanges, computeSolverConstraints, printCompute);
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
bool isBoolSort(llvm::SMTSolverRef solver) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
unsigned bitWidth() const
Maps mlir::Values to LatticeValues.
IntervalAnalysisLatticeValue LatticeValue
ValueMap::iterator begin()
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult setValue(mlir::Value v, const LatticeValue &val)
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i)
mlir::LogicalResult visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
Intervals over a finite field.
static Interval True(const Field &f)
Interval intersect(const Interval &rhs) const
Intersect.
static Interval Boolean(const Field &f)
bool isDegenerate() const
static Interval False(const Field &f)
static Interval Degenerate(const Field &f, const llvm::DynamicAPInt &val)
Interval join(const Interval &rhs) const
Union.
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
A lattice for use in dense analysis.
A reference to a "source", which is the base value from which other SSA values are derived.
SourceRef createChild(SourceRefIndex r) const
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
::llzk::boolean::FeltCmpPredicate getPredicate()
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
bool isSingleValue() const
const ScalarTy & getScalarValue() const
IntervalAnalysisLattice * getLattice(mlir::LatticeAnchor anchor) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
llvm::SMTExprRef getSymbol(const SourceRef &r) const