16#include <mlir/Dialect/SCF/IR/SCF.h>
18#include <llvm/ADT/TypeSwitch.h>
35 if (expr ==
nullptr && rhs.expr ==
nullptr) {
38 if (expr ==
nullptr || rhs.expr ==
nullptr) {
41 return i == rhs.i && *expr == *rhs.expr;
46 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
47 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
48 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.
getExpr(), one, zero);
55 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
62 res.i = lhs.i + rhs.i;
63 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
70 res.i = lhs.i - rhs.i;
71 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
78 res.i = lhs.i * rhs.i;
79 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
87 auto divRes = lhs.i / rhs.i;
90 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
91 " Range of division result will be treated as unbounded."
98 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
105 res.i = lhs.i % rhs.i;
106 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
113 res.i = lhs.i & rhs.i;
114 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
121 res.i = lhs.i << rhs.i;
122 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
129 res.i = lhs.i >> rhs.i;
130 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
141 case FeltCmpPredicate::EQ:
142 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
149 case FeltCmpPredicate::NE:
150 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
157 case FeltCmpPredicate::LT:
158 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
166 case FeltCmpPredicate::LE:
167 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
175 case FeltCmpPredicate::GT:
176 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
184 case FeltCmpPredicate::GE:
185 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
201 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
208 res.i =
boolOr(lhs.i, rhs.i);
209 res.expr = solver->mkOr(lhs.expr, rhs.expr);
218 res.expr = solver->mkAnd(
219 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
229 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
230 .Case<
OrFeltOp>([&](
auto) {
return solver->mkBVOr(lhs.expr, rhs.expr); })
231 .Case<XorFeltOp>([&](
auto) {
232 return solver->mkBVXor(lhs.expr, rhs.expr);
233 }).Default([&](
auto *unsupported) {
234 llvm::report_fatal_error(
235 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
246 res.expr = solver->mkBVNeg(val.expr);
253 res.expr = solver->mkBVNot(val.expr);
260 res.expr = solver->mkNot(val.expr);
269 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
275 llvm::SMTExprRef invSym = field.
createSymbol(solver, symName.c_str());
276 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.
bitWidth());
278 llvm::SMTExprRef mult = solver->mkBVMul(val.
getExpr(), invSym);
279 llvm::SMTExprRef
mod = solver->mkBVURem(mult, prime);
280 llvm::SMTExprRef constraint = solver->mkEqual(
mod, one);
281 solver->addConstraint(constraint);
283 }).Default([](Operation *unsupported) {
284 llvm::report_fatal_error(
285 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
297 os <<
"<null expression>";
300 os <<
" ( interval: " << i <<
" )";
308 llvm::report_fatal_error(
"invalid join lattice type");
310 ChangeResult res = val.update(rhs->getValue());
311 for (
auto &v : rhs->constraints) {
312 if (!constraints.contains(v)) {
313 constraints.insert(v);
314 res |= ChangeResult::Change;
323 llvm::report_fatal_error(
"invalid join lattice type");
330 for (
auto &v : rhs->constraints) {
331 if (!constraints.contains(v)) {
332 constraints.insert(v);
333 res |= ChangeResult::Change;
340 os <<
"IntervalAnalysisLattice { " << val <<
" }";
345 return ChangeResult::NoChange;
348 return ChangeResult::Change;
357 if (!constraints.contains(e)) {
358 constraints.insert(e);
359 return ChangeResult::Change;
361 return ChangeResult::NoChange;
367IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) {
368 ProgramPoint *pp = _dataflowSolver.getProgramPointAfter(baseOp);
369 auto defaultSourceRefLattice = _dataflowSolver.lookupState<
SourceRefLattice>(pp);
370 ensure(defaultSourceRefLattice,
"failed to get lattice");
371 if (Operation *defOp = val.getDefiningOp()) {
372 ProgramPoint *defPoint = _dataflowSolver.getProgramPointAfter(defOp);
373 auto sourceRefLattice = _dataflowSolver.lookupState<
SourceRefLattice>(defPoint);
374 ensure(sourceRefLattice,
"failed to get SourceRefLattice for value");
375 return sourceRefLattice;
377 return defaultSourceRefLattice;
381 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
390 if (operands.empty() && results.empty()) {
395 llvm::SmallVector<LatticeValue> operandVals;
396 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
397 for (
unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
398 Value val = op->getOperand(opNum);
403 operandRefs.push_back(std::nullopt);
406 auto priorState = operands[opNum]->
getValue();
407 if (priorState.getScalarValue().getExpr() !=
nullptr) {
408 operandVals.push_back(priorState);
415 Type valTy = val.getType();
416 if (llvm::isa<ArrayType, StructType>(valTy) && !
isSignalType(valTy)) {
418 operandVals.emplace_back(anyVal);
422 ensure(refSet.
isScalar(),
"should have ruled out array values already");
430 "state of ", val,
" is empty; defining operation is unsupported by SourceRef analysis"
438 debug::Appender(warning) <<
"operand " << val <<
" is not a single value " << refSet
439 <<
", overapproximating";
440 op->emitWarning(warning).report();
444 operandVals.emplace_back(anyVal);
448 if (
auto it = fieldWriteResults.find(ref); it != fieldWriteResults.end()) {
449 operandVals.emplace_back(it->second);
452 operandVals.emplace_back(exprVal);
460 (void)operandLattice->
setValue(operandVals[opNum]);
465 llvm::DynamicAPInt constVal = getConst(op);
466 llvm::SMTExprRef expr = createConstBitvectorExpr(constVal);
468 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
469 }
else if (isArithmeticOp(op)) {
471 if (operands.size() == 2) {
472 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
474 result = performUnaryArithmetic(op, operandVals[0]);
477 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
481 propagateIfChanged(results[0], results[0]->setValue(result));
482 }
else if (
EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
483 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
489 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
490 if (succeeded(res)) {
491 for (Value signalVal : res->first) {
492 applyInterval(emitEq, signalVal, res->second);
500 applyInterval(emitEq, lhsVal, constrainInterval);
501 applyInterval(emitEq, rhsVal, constrainInterval);
502 }
else if (
auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
505 Value cond = assertOp.getCondition();
508 auto assertExpr = operandVals[0].getScalarValue();
511 }
else if (
auto readf = llvm::dyn_cast<FieldReadOp>(op)) {
512 Value
cmp = readf.getComponent();
516 propagateIfChanged(results[0], results[0]->setValue(operandVals[0]));
518 }
else if (
auto writef = llvm::dyn_cast<FieldWriteOp>(op)) {
521 auto cmp = writef.getComponent();
525 auto fieldDefRes = writef.getFieldDefOp(tables);
526 if (succeeded(fieldDefRes)) {
532 if (
auto it = fieldWriteResults.find(fieldRef); it != fieldWriteResults.end()) {
535 fieldWriteResults[fieldRef] = old.
withInterval(combinedWrite);
537 fieldWriteResults[fieldRef] = written;
541 for (Lattice *readerLattice : fieldReadResults[fieldRef]) {
545 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
549 }
else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
556 expr =
boolToFelt(smtSolver, expr, field.get().bitWidth());
558 propagateIfChanged(results[0], results[0]->setValue(expr));
559 }
else if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
562 Operation *parent = op->getParentOp();
563 ensure(parent,
"yield operation must have parent operation");
565 for (
unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
566 Value parentRes = parent->getResult(idx);
572 if (exprVal.
getExpr() !=
nullptr) {
577 propagateIfChanged(resLattice, resLattice->
setValue(newResVal));
587 && !isDefinitionOp(op)
589 && !llvm::isa<CreateStructOp>(op)
591 op->emitWarning(
"unhandled operation, analysis may be incomplete").report();
598 auto it = refSymbols.find(r);
599 if (it != refSymbols.end()) {
602 llvm::SMTExprRef sym = createFeltSymbol(r);
607llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const SourceRef &r)
const {
611llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v)
const {
615llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const char *name)
const {
616 return field.get().createSymbol(smtSolver, name);
619llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op)
const {
620 ensure(isConstOp(op),
"op is not a const op");
622 llvm::DynamicAPInt fieldConst =
623 TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
624 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
625 llvm::APSInt constOpVal(feltConst.getValue());
626 return field.get().reduce(constOpVal);
628 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
629 return DynamicAPInt(indexConst.value());
631 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
632 return DynamicAPInt(intConst.value());
633 }).Default([](Operation *illegalOp) {
635 debug::Appender(err) <<
"unhandled getConst case: " << *illegalOp;
636 llvm::report_fatal_error(Twine(err));
637 return llvm::DynamicAPInt();
643 Operation *op,
const LatticeValue &a,
const LatticeValue &b
645 ensure(isArithmeticOp(op),
"is not arithmetic op");
647 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
648 ensure(lhs.getExpr(),
"cannot perform arithmetic over null lhs smt expr");
649 ensure(rhs.getExpr(),
"cannot perform arithmetic over null rhs smt expr");
651 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
652 .Case<AddFeltOp>([&](
auto _) {
return add(smtSolver, lhs, rhs); })
653 .Case<SubFeltOp>([&](
auto _) {
return sub(smtSolver, lhs, rhs); })
654 .Case<MulFeltOp>([&](
auto _) {
return mul(smtSolver, lhs, rhs); })
655 .Case<DivFeltOp>([&](
auto divOp) {
return div(smtSolver, divOp, lhs, rhs); })
656 .Case<ModFeltOp>([&](
auto _) {
return mod(smtSolver, lhs, rhs); })
657 .Case<AndFeltOp>([&](
auto _) {
return bitAnd(smtSolver, lhs, rhs); })
658 .Case<ShlFeltOp>([&](
auto _) {
return shiftLeft(smtSolver, lhs, rhs); })
659 .Case<ShrFeltOp>([&](
auto _) {
return shiftRight(smtSolver, lhs, rhs); })
660 .Case<CmpOp>([&](
auto cmpOp) {
return cmp(smtSolver, cmpOp, lhs, rhs); })
661 .Case<AndBoolOp>([&](
auto _) {
return boolAnd(smtSolver, lhs, rhs); })
662 .Case<OrBoolOp>([&](
auto _) {
return boolOr(smtSolver, lhs, rhs); })
663 .Case<XorBoolOp>([&](
auto _) {
664 return boolXor(smtSolver, lhs, rhs);
665 }).Default([&](
auto *unsupported) {
668 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
674 ensure(res.getExpr(),
"arithmetic produced null smt expr");
679IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op,
const LatticeValue &a) {
680 ensure(isArithmeticOp(op),
"is not arithmetic op");
682 auto val = a.getScalarValue();
683 ensure(val.getExpr(),
"cannot perform arithmetic over null smt expr");
685 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
686 .Case<NegFeltOp>([&](
auto _) {
return neg(smtSolver, val); })
687 .Case<NotFeltOp>([&](
auto _) {
return notOp(smtSolver, val); })
688 .Case<NotBoolOp>([&](
auto _) {
return boolNot(smtSolver, val); })
690 .Case<InvFeltOp>([&](
auto inv) {
692 }).Default([&](
auto *unsupported) {
695 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
701 ensure(res.getExpr(),
"arithmetic produced null smt expr");
705void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val,
Interval newInterval) {
707 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
710 ExpressionValue newLatticeVal = oldLatticeVal.withInterval(
intersection);
711 ChangeResult changed = valLattice->setValue(newLatticeVal);
713 if (
auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
714 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
717 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
718 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
719 auto structOp = fnOp->getParentOfType<StructDefOp>();
720 FuncDefOp computeFn = structOp.getComputeFuncOp();
721 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
724 SourceRef ref(computeArg);
726 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
731 Operation *definingOp = val.getDefiningOp();
733 propagateIfChanged(valLattice, changed);
737 const Field &f = field.get();
745 auto cmpCase = [&](CmpOp cmpOp) {
751 newInterval.isBoolean() || newInterval.isEmpty(),
752 "new interval for CmpOp is not boolean or empty"
754 if (!newInterval.isDegenerate()) {
759 bool cmpTrue = newInterval.rhs() == f.one();
761 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
763 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
764 rhsExpr = rhsLat->getValue().getScalarValue();
766 Interval newLhsInterval, newRhsInterval;
767 const Interval &lhsInterval = lhsExpr.getInterval();
768 const Interval &rhsInterval = rhsExpr.getInterval();
772 auto eqCase = [&]() {
773 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
774 (pred == FeltCmpPredicate::NE && !cmpTrue);
776 auto neCase = [&]() {
777 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
778 (pred == FeltCmpPredicate::EQ && !cmpTrue);
780 auto ltCase = [&]() {
781 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
782 (pred == FeltCmpPredicate::GE && !cmpTrue);
784 auto leCase = [&]() {
785 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
786 (pred == FeltCmpPredicate::GT && !cmpTrue);
788 auto gtCase = [&]() {
789 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
790 (pred == FeltCmpPredicate::LE && !cmpTrue);
792 auto geCase = [&]() {
793 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
794 (pred == FeltCmpPredicate::LT && !cmpTrue);
799 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
800 }
else if (neCase()) {
801 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
805 }
else if (lhsInterval.isDegenerate()) {
807 newLhsInterval = lhsInterval;
808 newRhsInterval = rhsInterval.difference(lhsInterval);
809 }
else if (rhsInterval.isDegenerate()) {
811 newLhsInterval = lhsInterval.difference(rhsInterval);
812 newRhsInterval = rhsInterval;
815 newLhsInterval = lhsInterval;
816 newRhsInterval = rhsInterval;
818 }
else if (ltCase()) {
819 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
820 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
821 }
else if (leCase()) {
822 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
823 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
824 }
else if (gtCase()) {
825 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
826 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
827 }
else if (geCase()) {
828 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
829 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
831 cmpOp->emitWarning(
"unhandled cmp predicate").report();
836 applyInterval(cmpOp, lhs, newLhsInterval);
837 applyInterval(cmpOp, rhs, newRhsInterval);
845 auto mulCase = [&](MulFeltOp mulOp) {
847 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
849 APInt constVal = constOperand.getValue();
850 if (constVal.isZero()) {
855 applyInterval(mulOp, multiplicand, updatedInterval);
858 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
860 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
861 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
863 if (lhsConstOp && rhsConstOp) {
865 }
else if (lhsConstOp) {
866 constCase(lhsConstOp, rhs);
868 }
else if (rhsConstOp) {
869 constCase(rhsConstOp, lhs);
875 if (newInterval.intersect(zeroInt).isNotEmpty()) {
881 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
882 rhsExpr = rhsLat->getValue().getScalarValue();
883 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
884 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
885 applyInterval(mulOp, lhs, newLhsInterval);
886 applyInterval(mulOp, rhs, newRhsInterval);
889 auto addCase = [&](AddFeltOp addOp) {
890 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
892 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
893 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
895 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
897 Interval derivedLhsInt = newInterval - currRhsInt;
898 Interval derivedRhsInt = newInterval - currLhsInt;
900 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
901 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
903 applyInterval(addOp, lhs, finalLhsInt);
904 applyInterval(addOp, rhs, finalRhsInt);
907 auto subCase = [&](SubFeltOp subOp) {
908 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
910 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
911 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
913 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
915 Interval derivedLhsInt = newInterval + currRhsInt;
916 Interval derivedRhsInt = currLhsInt - newInterval;
918 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
919 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
921 applyInterval(subOp, lhs, finalLhsInt);
922 applyInterval(subOp, rhs, finalRhsInt);
925 auto readfCase = [&](FieldReadOp readfOp) {
926 const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val);
927 SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val);
929 if (sourceRefVal.isSingleValue()) {
930 const SourceRef &ref = sourceRefVal.getSingleValue();
931 fieldReadResults[ref].insert(valLattice);
934 for (Lattice *l : fieldReadResults[ref]) {
935 if (l != valLattice) {
936 propagateIfChanged(l, l->setValue(newLatticeVal));
944 Value comp = readfOp.getComponent();
946 applyInterval(readfOp, comp, newInterval);
950 auto readArrCase = [&](ReadArrayOp _) {
951 const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val);
952 SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val);
954 if (sourceRefVal.isSingleValue()) {
955 const SourceRef &ref = sourceRefVal.getSingleValue();
956 fieldReadResults[ref].insert(valLattice);
959 for (Lattice *l : fieldReadResults[ref]) {
960 if (l != valLattice) {
961 propagateIfChanged(l, l->setValue(newLatticeVal));
968 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
974 TypeSwitch<Operation *>(definingOp)
975 .Case<CmpOp>([&](
auto op) { cmpCase(op); })
976 .Case<AddFeltOp>([&](
auto op) {
return addCase(op); })
977 .Case<SubFeltOp>([&](
auto op) {
return subCase(op); })
978 .Case<MulFeltOp>([&](
auto op) { mulCase(op); })
979 .Case<FieldReadOp>([&](
auto op){ readfCase(op); })
980 .Case<ReadArrayOp>([&](
auto op){ readArrCase(op); })
981 .Case<IntToFeltOp, FeltToIndexOp>([&](
auto op) { castCase(op); })
982 .Default([&](Operation *) { });
986 propagateIfChanged(valLattice, changed);
989FailureOr<std::pair<DenseSet<Value>,
Interval>>
990IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
991 auto isZeroConst = [
this](Value v) {
992 Operation *op = v.getDefiningOp();
996 if (!isConstOp(op)) {
999 return getConst(op) == field.get().zero();
1001 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1002 Value exprTree =
nullptr;
1003 if (lhsIsZero && !rhsIsZero) {
1005 }
else if (!lhsIsZero && rhsIsZero) {
1012 std::optional<SourceRef> signalRef = std::nullopt;
1013 DenseSet<Value> signalVals;
1014 SmallVector<DynamicAPInt> consts;
1015 SmallVector<Value> frontier {exprTree};
1016 while (!frontier.empty()) {
1017 Value v = frontier.back();
1018 frontier.pop_back();
1019 Operation *op = v.getDefiningOp();
1023 auto handleRefValue = [
this, &baseOp, &signalRef, &signalVal, &signalVals]() {
1024 SourceRefLatticeValue refSet =
1025 getSourceRefLattice(baseOp, signalVal)->getOrDefault(signalVal);
1026 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1029 SourceRef r = refSet.getSingleValue();
1030 if (signalRef.has_value() && signalRef.value() != r) {
1032 }
else if (!signalRef.has_value()) {
1035 signalVals.insert(signalVal);
1040 if (op && matchPattern(op, subPattern)) {
1041 if (failed(handleRefValue())) {
1044 auto constInt = APSInt(c.getValue());
1045 consts.push_back(field.get().reduce(constInt));
1047 }
else if (
m_RefValue(&signalVal).match(v)) {
1048 if (failed(handleRefValue())) {
1051 consts.push_back(field.get().zero());
1057 if (op && matchPattern(op, mulPattern)) {
1058 frontier.push_back(a);
1059 frontier.push_back(b);
1068 std::sort(consts.begin(), consts.end());
1069 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1070 return std::make_pair(std::move(signalVals), iv);
1079 auto validSourceRefType = [](
const SourceRef &ref) {
1082 if (!ref.isScalar() && !ref.isSignal()) {
1086 if (
auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1092 auto computeIntervalsImpl = [&solver, &ctx, &validSourceRefType,
this](
1093 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &fieldRanges,
1094 llvm::SetVector<ExpressionValue> &solverConstraints
1105 if (validSourceRefType(ref)) {
1106 searchSet.insert(ref);
1111 for (BlockArgument arg : fn.getArguments()) {
1113 if (searchSet.erase(ref)) {
1117 if (!expr.getExpr()) {
1120 fieldRanges[ref] = expr.getInterval();
1121 assert(fieldRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1128 if (!lattices.empty() && searchSet.erase(ref)) {
1131 assert(fieldRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1136 if (searchSet.erase(ref)) {
1137 fieldRanges[ref] = val.getInterval();
1138 assert(fieldRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1143 for (
const auto &ref : searchSet) {
1148 llvm::sort(fieldRanges, [](
auto a,
auto b) {
return std::get<0>(a) < std::get<0>(b); });
1151 computeIntervalsImpl(structDef.getComputeFuncOp(), computeFieldRanges, computeSolverConstraints);
1152 computeIntervalsImpl(
1153 structDef.getConstrainFuncOp(), constrainFieldRanges, constrainSolverConstraints
1160 auto writeIntervals =
1161 [&os, &withConstraints](
1162 const char *fnName,
const llvm::MapVector<SourceRef, Interval> &fieldRanges,
1163 const llvm::SetVector<ExpressionValue> &solverConstraints,
bool printName
1168 os.indent(indent) << fnName <<
" {";
1172 if (fieldRanges.empty()) {
1177 for (
auto &[ref, interval] : fieldRanges) {
1179 os.indent(indent) << ref <<
" in " << interval;
1182 if (withConstraints) {
1184 os.indent(indent) <<
"Solver Constraints { ";
1185 if (solverConstraints.empty()) {
1188 for (
const auto &e : solverConstraints) {
1190 os.indent(indent + 4);
1191 e.getExpr()->print(os);
1194 os.indent(indent) <<
'}';
1200 os.indent(indent - 4) <<
'}';
1204 os <<
"StructIntervals { ";
1205 if (constrainFieldRanges.empty() && (!printCompute || computeFieldRanges.empty())) {
1211 writeIntervals(
FUNC_NAME_COMPUTE, computeFieldRanges, computeSolverConstraints, printCompute);
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
bool isBoolSort(llvm::SMTSolverRef solver) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
unsigned bitWidth() const
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getFieldReadResults() const
const llvm::DenseMap< SourceRef, ExpressionValue > & getFieldWriteResults() const
Intervals over a finite field.
static Interval True(const Field &f)
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
static Interval Entire(const Field &f)
bool isDegenerate() const
static Interval False(const Field &f)
Interval join(const Interval &rhs) const
Union.
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
A lattice for use in dense analysis.
A reference to a "source", which is the base value from which other SSA values are derived.
SourceRef createChild(SourceRefIndex r) const
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
::llzk::boolean::FeltCmpPredicate getPredicate()
std::variant< ScalarTy, ArrayTy > & getValue()
bool isSingleValue() const
const ScalarTy & getScalarValue() const
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA