15#include <mlir/Dialect/SCF/IR/SCF.h>
17#include <llvm/ADT/TypeSwitch.h>
34 if (expr ==
nullptr && rhs.expr ==
nullptr) {
37 if (expr ==
nullptr || rhs.expr ==
nullptr) {
40 return i == rhs.i && *expr == *rhs.expr;
46 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
53 res.i = lhs.i + rhs.i;
54 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
61 res.i = lhs.i - rhs.i;
62 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
69 res.i = lhs.i * rhs.i;
70 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
78 auto divRes = lhs.i / rhs.i;
81 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
82 " Range of division result will be treated as unbounded."
88 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
95 res.i = lhs.i % rhs.i;
96 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
103 res.i = lhs.i & rhs.i;
104 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
111 res.i = lhs.i << rhs.i;
112 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
119 res.i = lhs.i >> rhs.i;
120 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
131 case FeltCmpPredicate::EQ:
132 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
139 case FeltCmpPredicate::NE:
140 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
147 case FeltCmpPredicate::LT:
148 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
150 case FeltCmpPredicate::LE:
151 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
153 case FeltCmpPredicate::GT:
154 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
156 case FeltCmpPredicate::GE:
157 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
167 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
174 res.i =
boolOr(lhs.i, rhs.i);
175 res.expr = solver->mkOr(lhs.expr, rhs.expr);
184 res.expr = solver->mkAnd(
185 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
195 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
196 .Case<
OrFeltOp>([&](
auto _) {
return solver->mkBVOr(lhs.expr, rhs.expr); })
197 .Case<XorFeltOp>([&](
auto _) {
198 return solver->mkBVXor(lhs.expr, rhs.expr);
199 }).Default([&](
auto *unsupported) {
200 llvm::report_fatal_error(
201 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
212 res.expr = solver->mkBVNeg(val.expr);
219 res.expr = solver->mkBVNot(val.expr);
226 res.expr = solver->mkNot(val.expr);
235 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
241 llvm::SMTExprRef invSym = field.
createSymbol(solver, symName.c_str());
242 llvm::SMTExprRef one = solver->mkBitvector(field.
one(), field.
bitWidth());
243 llvm::SMTExprRef prime = solver->mkBitvector(field.
prime(), field.
bitWidth());
244 llvm::SMTExprRef mult = solver->mkBVMul(val.
getExpr(), invSym);
245 llvm::SMTExprRef
mod = solver->mkBVURem(mult, prime);
246 llvm::SMTExprRef constraint = solver->mkEqual(
mod, one);
247 solver->addConstraint(constraint);
249 }).Default([&](Operation *unsupported) {
250 llvm::report_fatal_error(
251 "no fallback provided for " + mlir::Twine(op->getName().getStringRef())
263 os <<
"<null expression>";
266 os <<
" ( interval: " << i <<
" )";
274 llvm::report_fatal_error(
"invalid join lattice type");
276 ChangeResult res = ChangeResult::NoChange;
277 for (
auto &[k, v] : rhs->valMap) {
278 auto it = valMap.find(k);
279 if (it == valMap.end() || it->second != v) {
281 res |= ChangeResult::Change;
284 for (
auto &v : rhs->constraints) {
285 if (!constraints.contains(v)) {
286 constraints.insert(v);
287 res |= ChangeResult::Change;
290 for (
auto &[e, i] : rhs->intervals) {
291 auto it = intervals.find(e);
292 if (it == intervals.end() || it->second != i) {
294 res |= ChangeResult::Change;
301 os <<
"IntervalAnalysisLattice { ";
302 for (
auto &[ref, val] : valMap) {
303 os <<
"\n (valMap) " << ref <<
" := " << val;
305 for (
auto &[expr, interval] : intervals) {
306 os <<
"\n (intervals) ";
312 os <<
" in " << interval;
314 if (!valMap.empty()) {
321 auto it = valMap.find(v);
322 if (it == valMap.end()) {
328FailureOr<IntervalAnalysisLattice::LatticeValue>
330 auto it = fieldMap.find(v);
331 if (it == fieldMap.end()) {
334 auto fit = it->second.find(f);
335 if (fit == it->second.end()) {
343 if (valMap[v] == val) {
344 return ChangeResult::NoChange;
347 intervals[e.getExpr()] = e.getInterval();
348 return ChangeResult::Change;
353 if (fieldMap[v][f] == val) {
354 return ChangeResult::NoChange;
356 fieldMap[v][f] = val;
357 intervals[e.getExpr()] = e.getInterval();
358 return ChangeResult::Change;
362 if (!constraints.contains(e)) {
363 constraints.insert(e);
364 return ChangeResult::Change;
366 return ChangeResult::NoChange;
370 auto it = intervals.find(expr);
371 if (it != intervals.end()) {
378 auto it = intervals.find(expr);
379 if (it != intervals.end() && it->second == i) {
380 return ChangeResult::NoChange;
383 return ChangeResult::Change;
398 if (action == dataflow::CallControlFlowAction::EnterCallee) {
402 setToEntryState(after);
407 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
411 if (
auto *prev = call->getPrevNode()) {
416 ensure(beforeCall,
"could not get prior lattice");
419 propagateIfChanged(after, after->
join(*beforeCall));
425 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
436 Operation *op,
const Lattice &before, Lattice *after
438 ChangeResult changed = after->
join(before);
440 llvm::SmallVector<LatticeValue> operandVals;
443 ensure(constrainRefLattice,
"failed to get lattice");
445 for (OpOperand &operand : op->getOpOperands()) {
446 Value val = operand.get();
448 auto priorState = before.
getValue(val);
449 if (succeeded(priorState) && priorState->getScalarValue().getExpr() !=
nullptr) {
450 operandVals.push_back(*priorState);
457 Type valTy = val.getType();
458 if (mlir::isa<ArrayType, StructType>(valTy) && !
isSignalType(valTy)) {
459 operandVals.push_back(LatticeValue());
464 ensure(refSet.
isScalar(),
"should have ruled out array values already");
470 op->emitWarning() <<
"state of " << val
471 <<
" is empty; defining operation is unsupported by constrain ref analysis";
472 propagateIfChanged(after, changed);
476 debug::Appender(warning) <<
"operand " << val <<
" is not a single value " << refSet
477 <<
", overapproximating";
478 op->emitWarning(warning);
482 changed |= after->
setValue(val, anyVal);
483 operandVals.emplace_back(anyVal);
487 if (succeeded(priorState)) {
488 exprVal = exprVal.
withInterval(priorState->getScalarValue().getInterval());
490 changed |= after->
setValue(val, exprVal);
491 operandVals.emplace_back(exprVal);
497 auto constVal = getConst(op);
498 auto expr = createConstBitvectorExpr(constVal);
500 changed |= after->
setValue(op->getResult(0), latticeVal);
501 }
else if (isArithmeticOp(op)) {
502 ensure(operandVals.size() <= 2,
"arithmetic op with the wrong number of operands");
504 if (operandVals.size() == 2) {
505 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
507 result = performUnaryArithmetic(op, operandVals[0]);
510 changed |= after->
setValue(op->getResult(0), result);
511 }
else if (
EmitEqualityOp emitEq = mlir::dyn_cast<EmitEqualityOp>(op)) {
512 ensure(operandVals.size() == 2,
"constraint op with the wrong number of operands");
513 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
519 auto res = getGeneralizedDecompInterval(constrainRefLattice, lhsVal, rhsVal);
520 if (succeeded(res)) {
521 for (Value signalVal : res->first) {
522 changed |= applyInterval(emitEq, after, signalVal, res->second);
529 changed |= applyInterval(emitEq, after, lhsVal, constraint.
getInterval());
530 changed |= applyInterval(emitEq, after, rhsVal, constraint.
getInterval());
532 }
else if (
AssertOp assertOp = mlir::dyn_cast<AssertOp>(op)) {
533 ensure(operandVals.size() == 1,
"assert op with the wrong number of operands");
536 changed |= applyInterval(
537 assertOp, after, assertOp.getCondition(),
541 auto assertExpr = operandVals[0].getScalarValue();
543 }
else if (
auto readf = mlir::dyn_cast<FieldReadOp>(op)) {
547 changed |= after->
setValue(readf.getVal(), operandVals[0].getScalarValue());
548 }
else if (
auto storedVal =
549 before.
getValue(readf.getComponent(), readf.getFieldNameAttr().getAttr());
550 succeeded(storedVal)) {
552 changed |= after->
setValue(readf.getVal(), storedVal->getScalarValue());
554 }
else if (
auto writef = mlir::dyn_cast<FieldWriteOp>(op)) {
558 after->
setValue(writef.getComponent(), writef.getFieldNameAttr().getAttr(), writeVal);
562 auto fieldDefRes = writef.getFieldDefOp(tables);
563 if (succeeded(fieldDefRes)) {
570 }
else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
572 changed |= after->
setValue(op->getResult(0), operandVals[0].getScalarValue());
573 }
else if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
575 Operation *parent = op->getParentOp();
576 ensure(parent,
"yield operation must have parent lattice");
578 ensure(parentLattice,
"could not fetch parent lattice");
580 for (
unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
581 Value parentRes = parent->getResult(idx);
583 auto exprValRes = parentLattice->getValue(parentRes);
585 if (succeeded(exprValRes)) {
592 changed |= after->
setValue(parentRes, newResVal);
595 propagateIfChanged(parentLattice, parentLattice->join(*after));
604 && !isDefinitionOp(op)
606 && !mlir::isa<CreateStructOp>(op)
608 op->emitWarning(
"unhandled operation, analysis may be incomplete");
611 propagateIfChanged(after, changed);
615 auto it = refSymbols.find(r);
616 if (it != refSymbols.end()) {
619 auto sym = createFeltSymbol(r);
624llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const ConstrainRef &r)
const {
628llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v)
const {
632llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(
const char *
name)
const {
633 return field.get().createSymbol(smtSolver,
name);
636llvm::APSInt IntervalDataFlowAnalysis::getConst(Operation *op)
const {
637 ensure(isConstOp(op),
"op is not a const op");
639 llvm::APInt fieldConst =
640 TypeSwitch<Operation *, llvm::APInt>(op)
641 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
642 llvm::APSInt constOpVal(feltConst.getValueAttr().getValue());
643 return field.get().reduce(constOpVal);
645 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
646 return llvm::APInt(field.get().bitWidth(), indexConst.value());
648 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
649 return llvm::APInt(field.get().bitWidth(), intConst.value());
650 }).Default([](Operation *illegalOp) {
652 debug::Appender(err) <<
"unhandled getConst case: " << *illegalOp;
653 llvm::report_fatal_error(Twine(err));
654 return llvm::APInt();
660 Operation *op,
const LatticeValue &a,
const LatticeValue &b
662 ensure(isArithmeticOp(op),
"is not arithmetic op");
664 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
665 ensure(lhs.getExpr(),
"cannot perform arithmetic over null lhs smt expr");
666 ensure(rhs.getExpr(),
"cannot perform arithmetic over null rhs smt expr");
668 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
669 .Case<AddFeltOp>([&](
auto _) {
return add(smtSolver, lhs, rhs); })
670 .Case<SubFeltOp>([&](
auto _) {
return sub(smtSolver, lhs, rhs); })
671 .Case<MulFeltOp>([&](
auto _) {
return mul(smtSolver, lhs, rhs); })
672 .Case<DivFeltOp>([&](
auto divOp) {
return div(smtSolver, divOp, lhs, rhs); })
673 .Case<ModFeltOp>([&](
auto _) {
return mod(smtSolver, lhs, rhs); })
674 .Case<AndFeltOp>([&](
auto _) {
return bitAnd(smtSolver, lhs, rhs); })
675 .Case<ShlFeltOp>([&](
auto _) {
return shiftLeft(smtSolver, lhs, rhs); })
676 .Case<ShrFeltOp>([&](
auto _) {
return shiftRight(smtSolver, lhs, rhs); })
677 .Case<CmpOp>([&](
auto cmpOp) {
return cmp(smtSolver, cmpOp, lhs, rhs); })
678 .Case<AndBoolOp>([&](
auto _) {
return boolAnd(smtSolver, lhs, rhs); })
679 .Case<OrBoolOp>([&](
auto _) {
return boolOr(smtSolver, lhs, rhs); })
680 .Case<XorBoolOp>([&](
auto _) {
681 return boolXor(smtSolver, lhs, rhs);
682 }).Default([&](
auto *unsupported) {
683 unsupported->emitWarning(
684 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
689 ensure(res.getExpr(),
"arithmetic produced null smt expr");
694IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op,
const LatticeValue &a) {
695 ensure(isArithmeticOp(op),
"is not arithmetic op");
697 auto val = a.getScalarValue();
698 ensure(val.getExpr(),
"cannot perform arithmetic over null smt expr");
700 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
701 .Case<NegFeltOp>([&](
auto _) {
return neg(smtSolver, val); })
702 .Case<NotFeltOp>([&](
auto _) {
return notOp(smtSolver, val); })
703 .Case<NotBoolOp>([&](
auto _) {
704 return boolNot(smtSolver, val);
705 }).Default([&](
auto *unsupported) {
706 unsupported->emitWarning(
707 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
712 ensure(res.getExpr(),
"arithmetic produced null smt expr");
716ChangeResult IntervalDataFlowAnalysis::applyInterval(
717 Operation *originalOp, Lattice *after, Value val,
Interval newInterval
719 auto latValRes = after->getValue(val);
720 if (failed(latValRes)) {
722 return ChangeResult::NoChange;
724 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
725 ChangeResult res = after->setValue(val, newLatticeVal);
728 Lattice *valLattice =
nullptr;
729 if (
auto valOp = val.getDefiningOp()) {
733 if (
auto prev = valOp->getPrevNode()) {
734 valLattice = getOrCreate<Lattice>(prev);
736 valLattice = getOrCreate<Lattice>(valOp->getBlock());
738 }
else if (
auto blockArg = mlir::dyn_cast<BlockArgument>(val)) {
739 Operation *owningOp = blockArg.getOwner()->getParentOp();
740 if (propagateInputConstraints) {
742 auto fnOp = dyn_cast<FuncDefOp>(owningOp);
743 if (fnOp && fnOp.isStructConstrain() && blockArg.getArgNumber() > 0 &&
744 !newInterval.isEntire()) {
745 auto structOp = fnOp->getParentOfType<StructDefOp>();
746 FuncDefOp computeFn = structOp.getComputeFuncOp();
747 Operation *computeEntry = &computeFn.getRegion().front().front();
748 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
749 Lattice *computeEntryLattice = getOrCreate<Lattice>(computeEntry);
750 auto entryLatticeVal = computeEntryLattice->getValue(computeArg);
751 ExpressionValue newArgVal;
752 if (succeeded(entryLatticeVal)) {
753 newArgVal = entryLatticeVal->getScalarValue().withInterval(newInterval);
758 newArgVal = ExpressionValue(
nullptr, newInterval);
760 ChangeResult computeRes = computeEntryLattice->setValue(computeArg, newArgVal);
761 propagateIfChanged(computeEntryLattice, computeRes);
765 valLattice = getOrCreate<Lattice>(blockArg.getOwner());
767 valLattice = getOrCreate<Lattice>(val);
769 ensure(valLattice,
"val should have a lattice");
770 auto setNewVal = [&valLattice, &after, &val, &newLatticeVal,
this]() {
771 if (valLattice != after) {
772 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
777 Operation *definingOp = val.getDefiningOp();
783 const Field &f = field.get();
791 auto cmpCase = [&](CmpOp cmpOp) {
795 newInterval.isBoolean(),
796 "new interval for CmpOp outside of allowed boolean range or is empty"
798 if (!newInterval.isDegenerate()) {
800 return ChangeResult::NoChange;
803 bool cmpTrue = newInterval.rhs() == f.one();
805 Value lhs = cmpOp->getOperand(0), rhs = cmpOp->getOperand(1);
806 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
807 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
808 return ChangeResult::NoChange;
810 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
811 rhsExpr = rhsLatValRes->getScalarValue();
813 Interval newLhsInterval, newRhsInterval;
814 const Interval &lhsInterval = lhsExpr.getInterval();
815 const Interval &rhsInterval = rhsExpr.getInterval();
819 auto eqCase = [&]() {
820 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
821 (pred == FeltCmpPredicate::NE && !cmpTrue);
823 auto neCase = [&]() {
824 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
825 (pred == FeltCmpPredicate::EQ && !cmpTrue);
827 auto ltCase = [&]() {
828 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
829 (pred == FeltCmpPredicate::GE && !cmpTrue);
831 auto leCase = [&]() {
832 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
833 (pred == FeltCmpPredicate::GT && !cmpTrue);
835 auto gtCase = [&]() {
836 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
837 (pred == FeltCmpPredicate::LE && !cmpTrue);
839 auto geCase = [&]() {
840 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
841 (pred == FeltCmpPredicate::LT && !cmpTrue);
846 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
847 }
else if (neCase()) {
849 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
853 }
else if (lhsInterval.isDegenerate()) {
855 newLhsInterval = lhsInterval;
856 newRhsInterval = rhsInterval.difference(lhsInterval);
857 }
else if (rhsInterval.isDegenerate()) {
859 newLhsInterval = lhsInterval.difference(rhsInterval);
860 newRhsInterval = rhsInterval;
863 newLhsInterval = lhsInterval;
864 newRhsInterval = rhsInterval;
866 }
else if (ltCase()) {
867 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
868 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
869 }
else if (leCase()) {
870 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
871 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
872 }
else if (gtCase()) {
873 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
874 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
875 }
else if (geCase()) {
876 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
877 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
879 cmpOp->emitWarning(
"unhandled cmp predicate");
880 return ChangeResult::NoChange;
884 return applyInterval(originalOp, after, lhs, newLhsInterval) |
885 applyInterval(originalOp, after, rhs, newRhsInterval);
890 auto mulCase = [&](MulFeltOp mulOp) {
892 if (newInterval.intersect(zeroInt).isNotEmpty()) {
894 return ChangeResult::NoChange;
897 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
898 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
899 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
900 return ChangeResult::NoChange;
902 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
903 rhsExpr = rhsLatValRes->getScalarValue();
904 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
905 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
906 return applyInterval(originalOp, after, lhs, newLhsInterval) |
907 applyInterval(originalOp, after, rhs, newRhsInterval);
913 auto readfCase = [&](FieldReadOp readfOp) {
914 Value comp = readfOp.getComponent();
916 return applyInterval(originalOp, after, comp, newInterval);
918 return ChangeResult::NoChange;
925 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
926 .Case<CmpOp>([&](
auto op) {
return cmpCase(op); })
927 .Case<MulFeltOp>([&](
auto op) {
return mulCase(op); })
928 .Case<FieldReadOp>([&](
auto op){
return readfCase(op); })
929 .Default([&](
auto *_) {
return ChangeResult::NoChange; });
938FailureOr<std::pair<DenseSet<Value>,
Interval>>
939IntervalDataFlowAnalysis::getGeneralizedDecompInterval(
942 auto isZeroConst = [
this](Value v) {
943 Operation *op = v.getDefiningOp();
947 if (!isConstOp(op)) {
950 llvm::APSInt c = getConst(op);
951 return safeEq(c, field.get().zero());
953 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
954 Value exprTree =
nullptr;
955 if (lhsIsZero && !rhsIsZero) {
957 }
else if (!lhsIsZero && rhsIsZero) {
964 std::optional<ConstrainRef> signalRef = std::nullopt;
965 DenseSet<Value> signalVals;
966 SmallVector<APSInt> consts;
967 SmallVector<Value> frontier {exprTree};
968 while (!frontier.empty()) {
969 Value v = frontier.back();
971 Operation *op = v.getDefiningOp();
975 auto handleRefValue = [&constrainRefLattice, &signalRef, &signalVal, &signalVals]() {
976 ConstrainRefLatticeValue refSet = constrainRefLattice->getOrDefault(signalVal);
977 if (!refSet.isScalar() || !refSet.isSingleValue()) {
980 ConstrainRef r = refSet.getSingleValue();
981 if (signalRef.has_value() && signalRef.value() != r) {
983 }
else if (!signalRef.has_value()) {
986 signalVals.insert(signalVal);
991 if (op && matchPattern(op, subPattern)) {
992 if (failed(handleRefValue())) {
995 auto constInt = APSInt(c.getValueAttr().getValue());
996 consts.push_back(field.get().reduce(constInt));
999 if (failed(handleRefValue())) {
1002 consts.push_back(field.get().zero());
1008 if (op && matchPattern(op, mulPattern)) {
1009 frontier.push_back(a);
1010 frontier.push_back(b);
1019 std::sort(consts.begin(), consts.end());
1020 Interval iv =
Interval::TypeA(field.get(), consts.front(), consts.back());
1021 return std::make_pair(std::move(signalVals), iv);
1029 auto computeIntervalsImpl = [&solver, &ctx,
this](
1031 llvm::MapVector<ConstrainRef, Interval> &fieldRanges,
1032 llvm::SetVector<ExpressionValue> &solverConstraints
1035 Operation *fnEnd = fn.getRegion().back().getTerminator();
1044 if (!ref.isScalar() && !ref.isSignal()) {
1048 if (
auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1054 if (succeeded(intervalRes)) {
1055 fieldRanges[ref] = *intervalRes;
1062 computeIntervalsImpl(structDef.getComputeFuncOp(), computeFieldRanges, computeSolverConstraints);
1063 computeIntervalsImpl(
1064 structDef.getConstrainFuncOp(), constrainFieldRanges, constrainSolverConstraints
1071 auto writeIntervals =
1072 [&os, &withConstraints](
1073 const char *fnName,
const llvm::MapVector<ConstrainRef, Interval> &fieldRanges,
1074 const llvm::SetVector<ExpressionValue> &solverConstraints,
bool printName
1079 os.indent(indent) << fnName <<
" {";
1083 if (fieldRanges.empty()) {
1088 for (
auto &[ref, interval] : fieldRanges) {
1090 os.indent(indent) << ref <<
" in " << interval;
1093 if (withConstraints) {
1095 os.indent(indent) <<
"Solver Constraints { ";
1096 if (solverConstraints.empty()) {
1099 for (
const auto &e : solverConstraints) {
1101 os.indent(indent + 4);
1102 e.getExpr()->print(os);
1105 os.indent(indent) <<
'}';
1111 os.indent(indent - 4) <<
'}';
1115 os <<
"StructIntervals { ";
1116 if (constrainFieldRanges.empty() && (!printCompute || computeFieldRanges.empty())) {
1122 writeIntervals(
FUNC_NAME_COMPUTE, computeFieldRanges, computeSolverConstraints, printCompute);
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
const ConstrainRef & getSingleValue() const
A lattice for use in dense analysis.
Defines a reference to a llzk object within a constrain function call.
ConstrainRef createChild(ConstrainRefIndex r) const
static std::vector< ConstrainRef > getAllConstrainRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ConstrainRef root)
Produce all possible ConstraintRefs that are present starting from the given root.
Tracks a solver expression and an interval range for that expression.
const Interval & getInterval() const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
unsigned bitWidth() const
llvm::APSInt prime() const
For the prime field p, returns p.
Maps mlir::Values to LatticeValues.
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, Interval i)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
Intervals over a finite field.
static Interval True(const Field &f)
Interval intersect(const Interval &rhs) const
Intersect.
static Interval Boolean(const Field &f)
bool isDegenerate() const
static Interval False(const Field &f)
static Interval Degenerate(const Field &f, llvm::APSInt val)
Interval join(const Interval &rhs) const
Union.
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, IntervalAnalysisContext &ctx)
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
::llzk::boolean::FeltCmpPredicate getPredicate()
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
bool isSingleValue() const
const ScalarTy & getScalarValue() const
IntervalAnalysisLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
llvm::APSInt safeToSigned(llvm::APSInt i)
Safely converts the given int to a signed int if it is an unsigned int by adding an extra bit for the...
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, llvm::Twine errMsg)
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
llvm::SMTExprRef getSymbol(const ConstrainRef &r)