LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
12#include "llzk/Util/Debug.h"
14
15#include <mlir/Dialect/SCF/IR/SCF.h>
16
17#include <llvm/ADT/TypeSwitch.h>
18
19using namespace mlir;
20
21namespace llzk {
22
23using namespace array;
24using namespace boolean;
25using namespace cast;
26using namespace component;
27using namespace constrain;
28using namespace felt;
29using namespace function;
30
31/* ExpressionValue */
32
34 if (expr == nullptr && rhs.expr == nullptr) {
35 return i == rhs.i;
36 }
37 if (expr == nullptr || rhs.expr == nullptr) {
38 return false;
39 }
40 return i == rhs.i && *expr == *rhs.expr;
41}
42
44boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth) {
45 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
46 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
47 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.getExpr(), one, zero);
48 return expr.withExpression(boolToFeltConv);
49}
50
51ExpressionValue
52intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
53 Interval res = lhs.i.intersect(rhs.i);
54 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
55 return ExpressionValue(exprEq, res);
56}
57
// Felt addition: the result interval is the interval sum of the operand ranges,
// and the SMT expression is a bitvector add of the operand expressions.
// NOTE(review): original lines 58 and 60 (presumably the `ExpressionValue`
// return type and the `ExpressionValue res;` declaration, as in the sibling
// helpers) are missing from this listing; confirm against the real file.
59add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
61 res.i = lhs.i + rhs.i;
62 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
63 return res;
64}
65
// Felt subtraction: the result interval is the interval difference of the
// operand ranges, and the SMT expression is a bitvector subtraction.
// NOTE(review): original lines 66 and 68 (presumably the return type and the
// `res` declaration) are missing from this listing; confirm against the real file.
67sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
69 res.i = lhs.i - rhs.i;
70 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
71 return res;
72}
73
// Felt multiplication: the result interval is the interval product of the
// operand ranges, and the SMT expression is a bitvector multiplication.
// NOTE(review): original lines 74 and 76 (presumably the return type and the
// `res` declaration) are missing from this listing; confirm against the real file.
75mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
77 res.i = lhs.i * rhs.i;
78 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
79 return res;
80}
81
// Felt division. The result interval comes from interval division of the
// operand ranges. If that division fails (per the warning text, because the
// divisor range may contain zero), a warning is reported on `op` and the result
// interval becomes the entire field. The SMT expression is an unsigned
// bitvector division.
// NOTE(review): original lines 82 and 85 (presumably the return type and the
// `res` declaration) are missing from this listing; confirm against the real file.
83div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs,
84 const ExpressionValue &rhs) {
86 auto divRes = lhs.i / rhs.i;
87 if (failed(divRes)) {
88 op->emitWarning(
89 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
90 " Range of division result will be treated as unbounded."
91 )
92 .report();
93 res.i = Interval::Entire(lhs.getField());
94 } else {
95 res.i = *divRes;
96 }
// NOTE(review): bvudiv is truncating unsigned integer division, not
// multiplication by a modular inverse; confirm this is the intended SMT model
// of DivFeltOp.
97 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
98 return res;
99}
100
102mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
103 ExpressionValue res;
104 res.i = lhs.i % rhs.i;
105 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
106 return res;
107}
108
110bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
111 ExpressionValue res;
112 res.i = lhs.i & rhs.i;
113 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
114 return res;
115}
116
118shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
119 ExpressionValue res;
120 res.i = lhs.i << rhs.i;
121 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
122 return res;
123}
124
126shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
127 ExpressionValue res;
128 res.i = lhs.i >> rhs.i;
129 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
130 return res;
131}
132
134cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs) {
135 ExpressionValue res;
136 const Field &f = lhs.getField();
137 // Default result is any boolean output for when we are unsure about the comparison result.
138 res.i = Interval::Boolean(f);
139 switch (op.getPredicate()) {
140 case FeltCmpPredicate::EQ:
141 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
142 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
143 res.i = lhs.i == rhs.i ? Interval::True(f) : Interval::False(f);
144 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
145 res.i = Interval::False(f);
146 }
147 break;
148 case FeltCmpPredicate::NE:
149 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
150 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
151 res.i = lhs.i != rhs.i ? Interval::True(f) : Interval::False(f);
152 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
153 res.i = Interval::True(f);
154 }
155 break;
156 case FeltCmpPredicate::LT:
157 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
158 break;
159 case FeltCmpPredicate::LE:
160 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
161 break;
162 case FeltCmpPredicate::GT:
163 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
164 break;
165 case FeltCmpPredicate::GE:
166 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
167 break;
168 }
169 return res;
170}
171
173boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
174 ExpressionValue res;
175 res.i = boolAnd(lhs.i, rhs.i);
176 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
177 return res;
178}
179
181boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
182 ExpressionValue res;
183 res.i = boolOr(lhs.i, rhs.i);
184 res.expr = solver->mkOr(lhs.expr, rhs.expr);
185 return res;
186}
187
189boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
190 ExpressionValue res;
191 res.i = boolXor(lhs.i, rhs.i);
192 // There's no Xor, so we do (L || R) && !(L && R)
193 res.expr = solver->mkAnd(
194 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
195 );
196 return res;
197}
198
// Over-approximating fallback for binary felt ops that have no interval transfer
// function. The result interval is the entire field. OrFeltOp and XorFeltOp get
// a bitvector or / xor SMT expression; any other op is a fatal error.
// NOTE(review): the line naming this function (original line 199) is missing
// from this listing. performBinaryArithmetic warns "defaulting to
// over-approximated intervals" before calling this, but ops other than
// Or/Xor abort here instead; confirm that is intended.
200 llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs
201) {
202 ExpressionValue res;
203 res.i = Interval::Entire(lhs.getField());
204 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
205 .Case<OrFeltOp>([&](auto) { return solver->mkBVOr(lhs.expr, rhs.expr); })
206 .Case<XorFeltOp>([&](auto) {
207 return solver->mkBVXor(lhs.expr, rhs.expr);
208 }).Default([&](auto *unsupported) {
209 llvm::report_fatal_error(
210 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
211 );
212 return nullptr;
213 });
214
215 return res;
216}
217
218ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val) {
219 ExpressionValue res;
220 res.i = -val.i;
221 res.expr = solver->mkBVNeg(val.expr);
222 return res;
223}
224
225ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val) {
226 ExpressionValue res;
227 res.i = ~val.i;
228 res.expr = solver->mkBVNot(val.expr);
229 return res;
230}
231
232ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val) {
233 ExpressionValue res;
234 res.i = boolNot(val.i);
235 res.expr = solver->mkNot(val.expr);
236 return res;
237}
238
240fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val) {
241 const Field &field = val.getField();
242 ExpressionValue res;
243 res.i = Interval::Entire(field);
244 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
245 .Case<InvFeltOp>([&](auto) {
246 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
247 // To create this expression, we create a new symbol for Y and add the
248 // XY % prime = 1 constraint to the solver.
249 std::string symName = buildStringViaInsertionOp(*op);
250 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
251 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.bitWidth());
252 llvm::SMTExprRef prime = solver->mkBitvector(toAPSInt(field.prime()), field.bitWidth());
253 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
254 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
255 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
256 solver->addConstraint(constraint);
257 return invSym;
258 }).Default([](Operation *unsupported) {
259 llvm::report_fatal_error(
260 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
261 );
262 return nullptr;
263 });
264
265 return res;
266}
267
268void ExpressionValue::print(mlir::raw_ostream &os) const {
269 if (expr) {
270 expr->print(os);
271 } else {
272 os << "<null expression>";
273 }
274
275 os << " ( interval: " << i << " )";
276}
277
278/* IntervalAnalysisLattice */
279
280ChangeResult IntervalAnalysisLattice::join(const AbstractDenseLattice &other) {
281 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
282 if (!rhs) {
283 llvm::report_fatal_error("invalid join lattice type");
284 }
285 ChangeResult res = ChangeResult::NoChange;
286 for (auto &[k, v] : rhs->valMap) {
287 auto it = valMap.find(k);
288 if (it == valMap.end() || it->second != v) {
289 valMap[k] = v;
290 res |= ChangeResult::Change;
291 }
292 }
293 for (auto &v : rhs->constraints) {
294 if (!constraints.contains(v)) {
295 constraints.insert(v);
296 res |= ChangeResult::Change;
297 }
298 }
299 for (auto &[e, i] : rhs->intervals) {
300 auto it = intervals.find(e);
301 if (it == intervals.end() || it->second != i) {
302 intervals[e] = i;
303 res |= ChangeResult::Change;
304 }
305 }
306 return res;
307}
308
309void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
310 os << "IntervalAnalysisLattice { ";
311 for (auto &[ref, val] : valMap) {
312 os << "\n (valMap) " << ref << " := " << val;
313 }
314 for (auto &[expr, interval] : intervals) {
315 os << "\n (intervals) ";
316 if (!expr) {
317 os << "<null expr>";
318 } else {
319 expr->print(os);
320 }
321 os << " in " << interval;
322 }
323 if (!valMap.empty()) {
324 os << '\n';
325 }
326 os << '}';
327}
328
329FailureOr<IntervalAnalysisLattice::LatticeValue> IntervalAnalysisLattice::getValue(Value v) const {
330 auto it = valMap.find(v);
331 if (it == valMap.end()) {
332 return failure();
333 }
334 return it->second;
335}
336
337FailureOr<IntervalAnalysisLattice::LatticeValue>
338IntervalAnalysisLattice::getValue(Value v, StringAttr f) const {
339 auto it = fieldMap.find(v);
340 if (it == fieldMap.end()) {
341 return failure();
342 }
343 auto fit = it->second.find(f);
344 if (fit == it->second.end()) {
345 return failure();
346 }
347 return fit->second;
348}
349
350ChangeResult IntervalAnalysisLattice::setValue(Value v, const LatticeValue &val) {
351 if (valMap[v] == val) {
352 return ChangeResult::NoChange;
353 }
354 valMap[v] = val;
355 ExpressionValue e = val.foldToScalar();
356 intervals[e.getExpr()] = e.getInterval();
357 return ChangeResult::Change;
358}
359
360ChangeResult IntervalAnalysisLattice::setValue(Value v, ExpressionValue e) {
361 LatticeValue val(e);
362 if (valMap[v] == val) {
363 return ChangeResult::NoChange;
364 }
365 valMap[v] = val;
366 intervals[e.getExpr()] = e.getInterval();
367 return ChangeResult::Change;
368}
369
370ChangeResult IntervalAnalysisLattice::setValue(Value v, StringAttr f, ExpressionValue e) {
371 LatticeValue val(e);
372 if (fieldMap[v][f] == val) {
373 return ChangeResult::NoChange;
374 }
375 fieldMap[v][f] = val;
376 intervals[e.getExpr()] = e.getInterval();
377 return ChangeResult::Change;
378}
379
381 if (!constraints.contains(e)) {
382 constraints.insert(e);
383 return ChangeResult::Change;
384 }
385 return ChangeResult::NoChange;
386}
387
388FailureOr<Interval> IntervalAnalysisLattice::findInterval(llvm::SMTExprRef expr) const {
389 auto it = intervals.find(expr);
390 if (it != intervals.end()) {
391 return it->second;
392 }
393 return failure();
394}
395
396ChangeResult IntervalAnalysisLattice::setInterval(llvm::SMTExprRef expr, const Interval &i) {
397 auto it = intervals.find(expr);
398 if (it != intervals.end() && it->second == i) {
399 return ChangeResult::NoChange;
400 }
401 intervals[expr] = i;
402 return ChangeResult::Change;
403}
404
405/* IntervalDataFlowAnalysis */
406
// Lattice transfer across call boundaries, dispatched on `action`:
//  - EnterCallee: the callee's lattice is reset to the entry state.
//  - ExitCallee: the lattice from just before the call site is joined into `after`.
//  - ExternalCallee: `before` is joined into `after` unchanged.
// NOTE(review): the function name and the `before`/`after` parameters
// (original lines 407-410 and 412) and original lines 414-416, 423-425 and
// 435-438 are missing from this listing.
411 CallOpInterface call, dataflow::CallControlFlowAction action,
413) {
417 if (action == dataflow::CallControlFlowAction::EnterCallee) {
418 // We skip updating the incoming lattice for function calls,
419 // as values are relative to the containing function/struct, so we don't need to pollute
420 // the callee with the callers values.
421 setToEntryState(after);
422 }
426 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
427 // Get the argument values of the lattice by getting the state as it would
428 // have been for the callsite.
429 const dataflow::AbstractDenseLattice *beforeCall = getLattice(getProgramPointBefore(call));
430 ensure(beforeCall, "could not get prior lattice");
431
432 // The lattice at the return is the lattice before the call
433 propagateIfChanged(after, after->join(*beforeCall));
434 }
// NOTE(review): the comment below mentions CDG edges, which this interval
// analysis does not compute; it may be carried over from another analysis.
439 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
440 // For external calls, we propagate what information we already have from
441 // before the call to after the call, since the external call won't invalidate
442 // any of that information. It also, conservatively, makes no assumptions about
443 // external calls and their computation, so CDG edges will not be computed over
444 // input arguments to external functions.
445 join(after, before);
446 }
447}
448
449const SourceRefLattice *
450IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) {
451 ProgramPoint *pp = _dataflowSolver.getProgramPointAfter(baseOp);
452 auto defaultSourceRefLattice = _dataflowSolver.lookupState<SourceRefLattice>(pp);
453 ensure(defaultSourceRefLattice, "failed to get lattice");
454 if (Operation *defOp = val.getDefiningOp()) {
455 ProgramPoint *defPoint = _dataflowSolver.getProgramPointAfter(defOp);
456 auto sourceRefLattice = _dataflowSolver.lookupState<SourceRefLattice>(defPoint);
457 ensure(sourceRefLattice, "failed to get SourceRefLattice for value");
458 return sourceRefLattice;
459 }
460 return defaultSourceRefLattice;
461}
462
// Dense-analysis transfer function for one operation. Operations outside a
// FuncDefOp are skipped. Otherwise it:
//  1. copies the enclosing function's argument values from `before` into `after`;
//  2. resolves an ExpressionValue for every operand, either from the lattice
//     where the operand is defined or from the operand's SourceRef symbol, and
//     records it in `after`;
//  3. applies an op-specific rule (constants, arithmetic, emit_eq, assert, field
//     read/write, int/index casts, scf.yield) to produce results or narrow
//     intervals;
//  4. propagates `after` if anything changed.
// Unsupported situations only emit warnings; every exit returns success() so
// partial, over-approximated results remain available.
463mlir::LogicalResult
464IntervalDataFlowAnalysis::visitOperation(Operation *op, const Lattice &before, Lattice *after) {
465 // We only perform the visitation on operations within functions
466 FuncDefOp fn = op->getParentOfType<FuncDefOp>();
467 if (!fn) {
468 return success();
469 }
470
471 ChangeResult changed = ChangeResult::NoChange;
472 // We always propagate the values of the function args from the function
473 // entry as the function context; if the input values are changed, this will
474 // force the recomputation of intervals throughout the function.
475 for (BlockArgument blockArg : fn.getArguments()) {
476 auto blockArgLookupRes = before.getValue(blockArg);
477 if (succeeded(blockArgLookupRes)) {
478 changed |= after->setValue(blockArg, *blockArgLookupRes);
479 }
480 }
481
// Lattice in which `val` first becomes available: after its defining op,
// before the first op of its owning block for block arguments, and otherwise
// the lattice before `op`.
482 auto getAfter = [&](Value val) {
483 if (Operation *defOp = val.getDefiningOp()) {
484 return getLattice(getProgramPointAfter(defOp));
485 } else if (auto blockArg = dyn_cast<BlockArgument>(val)) {
486 Operation *blockEntry = &blockArg.getOwner()->front();
487 return getLattice(getProgramPointBefore(blockEntry));
488 }
489 return getLattice(getProgramPointBefore(op));
490 };
491
// operandVals[k] and operandRefs[k] describe op->getOperand(k); operandRefs
// holds the operand's SourceRef only when the reference set is a single value.
492 llvm::SmallVector<LatticeValue> operandVals;
493 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
494 for (OpOperand &operand : op->getOpOperands()) {
495 Value val = operand.get();
496 SourceRefLatticeValue refSet = getSourceRefLattice(op, val)->getOrDefault(val);
497 if (refSet.isSingleValue()) {
498 operandRefs.push_back(refSet.getSingleValue());
499 } else {
500 operandRefs.push_back(std::nullopt);
501 }
502 // First, lookup the operand value after it is initialized
503 Lattice *valLattice = getAfter(val);
504 auto priorState = valLattice->getValue(val);
// A previously recorded value is reused only if it has a non-null SMT expression.
505 if (succeeded(priorState) && priorState->getScalarValue().getExpr() != nullptr) {
506 operandVals.push_back(*priorState);
507 changed |= after->setValue(val, *priorState);
508 continue;
509 }
510
511 // Else, look up the stored value by `SourceRef`.
512 // We only care about scalar type values, so we ignore composite types, which
513 // are currently limited to non-Signal structs and arrays.
514 Type valTy = val.getType();
515 if (llvm::isa<ArrayType, StructType>(valTy) && !isSignalType(valTy)) {
516 LatticeValue empty;
517 operandVals.push_back(empty);
518 changed |= after->setValue(val, empty);
519 continue;
520 }
521
522 ensure(refSet.isScalar(), "should have ruled out array values already");
523
524 if (refSet.getScalarValue().empty()) {
525 // If we can't compute the reference, then there must be some unsupported
526 // op the reference analysis cannot handle. We emit a warning and return
527 // early, since there's no meaningful computation we can do for this op.
528 op->emitWarning()
529 .append(
530 "state of ", val, " is empty; defining operation is unsupported by SourceRef analysis"
531 )
532 .report();
533 propagateIfChanged(after, changed);
534 // We still return success so we can return overapproximated and partial
535 // results to the user.
536 return success();
537 } else if (!refSet.isSingleValue()) {
538 std::string warning;
539 debug::Appender(warning) << "operand " << val << " is not a single value " << refSet
540 << ", overapproximating";
541 op->emitWarning(warning).report();
542 // Here, we will override the prior lattice value with a new symbol, representing
543 // "any" value, then use that value for the operands.
544 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
545 changed |= after->setValue(val, anyVal);
546 operandVals.emplace_back(anyVal);
547 } else {
// Single reference: use its memoized symbol, keeping any interval already
// recorded for the operand.
548 const SourceRef &ref = refSet.getSingleValue();
549 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
550 if (succeeded(priorState)) {
551 exprVal = exprVal.withInterval(priorState->getScalarValue().getInterval());
552 }
553 changed |= after->setValue(val, exprVal);
554 operandVals.emplace_back(exprVal);
555 }
556
557 // Since we initialized a value that was not found in the before lattice,
558 // update that value in the lattice so we can find it later, but we don't
559 // need to propagate the changes, since we already have what we need.
560 auto res = after->getValue(val);
561 ensure(succeeded(res), "expected precondition is that value is set");
562 (void)valLattice->setValue(val, *res);
563 }
564
565 // Now, the way we update is dependent on the type of the operation.
566 if (isConstOp(op)) {
567 llvm::DynamicAPInt constVal = getConst(op);
568 llvm::SMTExprRef expr = createConstBitvectorExpr(constVal);
569 ExpressionValue latticeVal(field.get(), expr, constVal);
570 changed |= after->setValue(op->getResult(0), latticeVal);
// NOTE(review): only `size() <= 2` is checked below, so an arithmetic op
// with zero operands would read operandVals[0] out of bounds; confirm
// isArithmeticOp never admits such ops.
571 } else if (isArithmeticOp(op)) {
572 ensure(operandVals.size() <= 2, "arithmetic op with the wrong number of operands");
573 ExpressionValue result;
574 if (operandVals.size() == 2) {
575 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
576 } else {
577 result = performUnaryArithmetic(op, operandVals[0]);
578 }
579
580 changed |= after->setValue(op->getResult(0), result);
581 } else if (EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
582 ensure(operandVals.size() == 2, "constraint op with the wrong number of operands");
583 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
584 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
585 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
586
587 // Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
588 // These patterns enforce that s is one of c0, ..., cN.
589 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
590 if (succeeded(res)) {
591 for (Value signalVal : res->first) {
592 changed |= applyInterval(emitEq, after, getAfter(signalVal), signalVal, res->second);
593 }
594 }
595
596 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
597 // Update the LHS and RHS to the same value, but restricted intervals
598 // based on the constraints.
599 const Interval &constrainInterval = constraint.getInterval();
600 changed |= applyInterval(emitEq, after, getAfter(lhsVal), lhsVal, constrainInterval);
601 changed |= applyInterval(emitEq, after, getAfter(rhsVal), rhsVal, constrainInterval);
602 changed |= after->addSolverConstraint(constraint);
603 } else if (AssertOp assertOp = llvm::dyn_cast<AssertOp>(op)) {
604 ensure(operandVals.size() == 1, "assert op with the wrong number of operands");
605 // assert enforces that the operand is true. So we apply an interval of [1, 1]
606 // to the operand.
607 changed |= applyInterval(
608 assertOp, after, after, assertOp.getCondition(),
609 Interval::Degenerate(field.get(), field.get().one())
610 );
611 // Also add the solver constraint that the expression must be true.
612 auto assertExpr = operandVals[0].getScalarValue();
613 changed |= after->addSolverConstraint(assertExpr);
614 } else if (auto readf = llvm::dyn_cast<FieldReadOp>(op)) {
615 Value cmp = readf.getComponent();
616 if (isSignalType(cmp.getType())) {
617 // The reg value read from the signal type is equal to the value of the Signal
618 // struct overall.
619 changed |= after->setValue(readf.getVal(), operandVals[0].getScalarValue());
620 } else {
// Stored field values are looked up in the lattice where the component
// value is defined (see getAfter), not in `before`.
621 auto storedVal = getAfter(cmp)->getValue(cmp, readf.getFieldNameAttr().getAttr());
622 if (succeeded(storedVal)) {
623 // The result value is the value previously written to this field.
624 changed |= after->setValue(readf.getVal(), storedVal->getScalarValue());
625 } else if (operandRefs[0].has_value()) {
626 // Initialize the value
627 auto fieldDefRes = readf.getFieldDefOp(tables);
628 if (succeeded(fieldDefRes)) {
629 SourceRef ref = operandRefs[0]->createChild(SourceRefIndex(*fieldDefRes));
630 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
631 changed |= after->setValue(readf.getVal(), exprVal);
632 }
633 }
634 }
635 } else if (auto writef = llvm::dyn_cast<FieldWriteOp>(op)) {
636 // Update values stored in a field
637 ExpressionValue writeVal = operandVals[1].getScalarValue();
638 auto cmp = writef.getComponent();
639 changed |= after->setValue(cmp, writef.getFieldNameAttr().getAttr(), writeVal);
640 // We also need to update the interval on the assigned symbol
641 SourceRefLatticeValue refSet = getSourceRefLattice(op, cmp)->getOrDefault(cmp);
642 if (refSet.isSingleValue()) {
643 auto fieldDefRes = writef.getFieldDefOp(tables);
644 if (succeeded(fieldDefRes)) {
645 SourceRefIndex idx(fieldDefRes.value());
646 SourceRef fieldRef = refSet.getSingleValue().createChild(idx);
647 llvm::SMTExprRef expr = getOrCreateSymbol(fieldRef);
648 changed |= after->setInterval(expr, writeVal.getInterval());
649 }
650 }
651 } else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
652 // Casts don't modify the intervals, but they do modify the SMT types.
653 ExpressionValue expr = operandVals[0].getScalarValue();
654 // We treat all ints and indexes as felts with the exception of comparison
655 // results, which are bools. So if `expr` is a bool, this cast needs to
656 // upcast to a felt.
657 if (expr.isBoolSort(smtSolver)) {
658 expr = boolToFelt(smtSolver, expr, field.get().bitWidth());
659 }
660 changed |= after->setValue(op->getResult(0), expr);
// NOTE(review): the loop below assumes yield operand k maps to parent result
// k. For a yield whose parent forwards it elsewhere (e.g. the body region of
// scf.while), the counts may differ; confirm such yields cannot reach here.
661 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
662 // Fetch the lattice for after the parent operation so we can propagate
663 // the yielded value to subsequent operations.
664 Operation *parent = op->getParentOp();
665 ensure(parent, "yield operation must have parent operation");
666 auto postYieldLattice = getLattice(getProgramPointAfter(parent));
667 ensure(postYieldLattice, "could not fetch post-yield lattice");
668 // Bind the operand values to the result values of the parent
669 for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
670 Value parentRes = parent->getResult(idx);
671 // Merge with the existing value, if present (e.g., another branch)
672 // has possible value that must be merged.
673 auto exprValRes = postYieldLattice->getValue(parentRes);
674 ExpressionValue newResVal = operandVals[idx].getScalarValue();
675 if (succeeded(exprValRes)) {
676 ExpressionValue existingVal = exprValRes->getScalarValue();
677 newResVal =
678 existingVal.withInterval(existingVal.getInterval().join(newResVal.getInterval()));
679 } else {
680 newResVal = ExpressionValue(createFeltSymbol(parentRes), newResVal.getInterval());
681 }
682 changed |= after->setValue(parentRes, newResVal);
683 }
684
685 propagateIfChanged(postYieldLattice, postYieldLattice->join(*after));
686 } else if (
687 // We do not need to explicitly handle read ops since they are resolved at the operand value
688 // step where `SourceRef`s are queries (with the exception of the Signal struct, see above).
689 !isReadOp(op)
690 // We do not currently handle return ops as the analysis is currently limited to constrain
691 // functions, which return no value.
692 && !isReturnOp(op)
693 // The analysis ignores definition ops.
694 && !isDefinitionOp(op)
695 // We do not need to analyze the creation of structs.
696 && !llvm::isa<CreateStructOp>(op)
697 ) {
698 op->emitWarning("unhandled operation, analysis may be incomplete").report();
699 }
700
701 propagateIfChanged(after, changed);
702 return success();
703}
704
706 auto it = refSymbols.find(r);
707 if (it != refSymbols.end()) {
708 return it->second;
709 }
710 llvm::SMTExprRef sym = createFeltSymbol(r);
711 refSymbols[r] = sym;
712 return sym;
713}
714
715llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const SourceRef &r) const {
716 return createFeltSymbol(buildStringViaPrint(r).c_str());
717}
718
719llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v) const {
720 return createFeltSymbol(buildStringViaPrint(v).c_str());
721}
722
723llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) const {
724 return field.get().createSymbol(smtSolver, name);
725}
726
727llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
728 ensure(isConstOp(op), "op is not a const op");
729
730 llvm::DynamicAPInt fieldConst =
731 TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
732 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
733 llvm::APSInt constOpVal(feltConst.getValue());
734 return field.get().reduce(constOpVal);
735 })
736 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
737 return DynamicAPInt(indexConst.value());
738 })
739 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
740 return DynamicAPInt(intConst.value());
741 }).Default([](Operation *illegalOp) {
742 std::string err;
743 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
744 llvm::report_fatal_error(Twine(err));
745 return llvm::DynamicAPInt();
746 });
747 return fieldConst;
748}
749
750ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
751 Operation *op, const LatticeValue &a, const LatticeValue &b
752) {
753 ensure(isArithmeticOp(op), "is not arithmetic op");
754
755 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
756 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
757 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
758
759 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
760 .Case<AddFeltOp>([&](auto _) { return add(smtSolver, lhs, rhs); })
761 .Case<SubFeltOp>([&](auto _) { return sub(smtSolver, lhs, rhs); })
762 .Case<MulFeltOp>([&](auto _) { return mul(smtSolver, lhs, rhs); })
763 .Case<DivFeltOp>([&](auto divOp) { return div(smtSolver, divOp, lhs, rhs); })
764 .Case<ModFeltOp>([&](auto _) { return mod(smtSolver, lhs, rhs); })
765 .Case<AndFeltOp>([&](auto _) { return bitAnd(smtSolver, lhs, rhs); })
766 .Case<ShlFeltOp>([&](auto _) { return shiftLeft(smtSolver, lhs, rhs); })
767 .Case<ShrFeltOp>([&](auto _) { return shiftRight(smtSolver, lhs, rhs); })
768 .Case<CmpOp>([&](auto cmpOp) { return cmp(smtSolver, cmpOp, lhs, rhs); })
769 .Case<AndBoolOp>([&](auto _) { return boolAnd(smtSolver, lhs, rhs); })
770 .Case<OrBoolOp>([&](auto _) { return boolOr(smtSolver, lhs, rhs); })
771 .Case<XorBoolOp>([&](auto _) {
772 return boolXor(smtSolver, lhs, rhs);
773 }).Default([&](auto *unsupported) {
774 unsupported
775 ->emitWarning(
776 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
777 )
778 .report();
779 return fallbackBinaryOp(smtSolver, unsupported, lhs, rhs);
780 });
781
782 ensure(res.getExpr(), "arithmetic produced null smt expr");
783 return res;
784}
785
787IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
788 ensure(isArithmeticOp(op), "is not arithmetic op");
789
790 auto val = a.getScalarValue();
791 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
792
793 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
794 .Case<NegFeltOp>([&](auto _) { return neg(smtSolver, val); })
795 .Case<NotFeltOp>([&](auto _) { return notOp(smtSolver, val); })
796 .Case<NotBoolOp>([&](auto _) { return boolNot(smtSolver, val); })
797 // The inverse op is currently overapproximated
798 .Case<InvFeltOp>([&](auto inv) {
799 return fallbackUnaryOp(smtSolver, inv, val);
800 }).Default([&](auto *unsupported) {
801 unsupported
802 ->emitWarning(
803 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
804 )
805 .report();
806 return fallbackUnaryOp(smtSolver, unsupported, val);
807 });
808
809 ensure(res.getExpr(), "arithmetic produced null smt expr");
810 return res;
811}
812
813ChangeResult IntervalDataFlowAnalysis::applyInterval(
814 Operation *originalOp, Lattice *originalLattice, Lattice *after, Value val, Interval newInterval
815) {
816 auto latValRes = after->getValue(val);
817 if (failed(latValRes)) {
818 // visitOperation didn't add val to the lattice, so there's nothing to do
819 return ChangeResult::NoChange;
820 }
821 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
822 propagateIfChanged(after, after->setValue(val, newLatticeVal));
823 ChangeResult res = originalLattice->setValue(val, newLatticeVal);
824 // To allow the dataflow analysis to do its fixed-point iteration, we need to
825 // add the new expression to val's lattice as well.
826 Lattice *valLattice = nullptr;
827 if (Operation *valOp = val.getDefiningOp()) {
828 valLattice = getLattice(getProgramPointAfter(valOp));
829 } else if (auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
830 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
831 Operation *blockEntry = &blockArg.getOwner()->front();
832
833 // Apply the interval from the constrain function inputs to the compute function inputs
834 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
835 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
836 auto structOp = fnOp->getParentOfType<StructDefOp>();
837 FuncDefOp computeFn = structOp.getComputeFuncOp();
838 Operation *computeEntry = &computeFn.getRegion().front().front();
839 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
840 Lattice *computeEntryLattice = getLattice(getProgramPointBefore(computeEntry));
841
842 SourceRef ref(computeArg);
843 ExpressionValue newArgVal(getOrCreateSymbol(ref), newInterval);
844 ChangeResult computeRes = computeEntryLattice->setValue(computeArg, newArgVal);
845 propagateIfChanged(computeEntryLattice, computeRes);
846 }
847
848 valLattice = getLattice(getProgramPointBefore(blockEntry));
849 } else {
850 valLattice = getLattice(val);
851 }
852
853 ensure(valLattice, "val should have a lattice");
854 auto setNewVal = [&valLattice, &val, &newLatticeVal, &res, this]() {
855 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
856 return res;
857 };
858
859 // Now we descend into val's operands, if it has any.
860 Operation *definingOp = val.getDefiningOp();
861 if (!definingOp) {
862 return setNewVal();
863 }
864 Lattice *definingOpLattice = getLattice(getProgramPointAfter(definingOp));
865 auto getOperandLattice = [&](Value operand) {
866 if (Operation *defOp = operand.getDefiningOp()) {
867 return getLattice(getProgramPointAfter(defOp));
868 } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
869 Operation *blockEntry = &blockArg.getOwner()->front();
870 return getLattice(getProgramPointBefore(blockEntry));
871 }
872 return definingOpLattice;
873 };
874 auto getOperandLatticeVal = [&](Value operand) {
875 return getOperandLattice(operand)->getValue(operand);
876 };
877
878 const Field &f = field.get();
879
880 // This is a rules-based operation. If we have a rule for a given operation,
881 // then we can make some kind of update, otherwise we leave the intervals
882 // as is.
883 // - First we'll define all the rules so the type switch can be less messy
884
885 // cmp.<pred> restricts each side of the comparison if the result is known.
886 auto cmpCase = [&](CmpOp cmpOp) {
887 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
888 // either "true" (1) or "false" (0)
889 ensure(
890 newInterval.isBoolean(),
891 "new interval for CmpOp outside of allowed boolean range or is empty"
892 );
893 if (!newInterval.isDegenerate()) {
894 // The comparison result is unknown, so we can't update the operand ranges
895 return ChangeResult::NoChange;
896 }
897
898 bool cmpTrue = newInterval.rhs() == f.one();
899
900 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
901 auto lhsLatValRes = getOperandLatticeVal(lhs), rhsLatValRes = getOperandLatticeVal(rhs);
902 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
903 return ChangeResult::NoChange;
904 }
905 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
906 rhsExpr = rhsLatValRes->getScalarValue();
907
908 Interval newLhsInterval, newRhsInterval;
909 const Interval &lhsInterval = lhsExpr.getInterval();
910 const Interval &rhsInterval = rhsExpr.getInterval();
911
912 FeltCmpPredicate pred = cmpOp.getPredicate();
913 // predicate cases
914 auto eqCase = [&]() {
915 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
916 (pred == FeltCmpPredicate::NE && !cmpTrue);
917 };
918 auto neCase = [&]() {
919 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
920 (pred == FeltCmpPredicate::EQ && !cmpTrue);
921 };
922 auto ltCase = [&]() {
923 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
924 (pred == FeltCmpPredicate::GE && !cmpTrue);
925 };
926 auto leCase = [&]() {
927 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
928 (pred == FeltCmpPredicate::GT && !cmpTrue);
929 };
930 auto gtCase = [&]() {
931 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
932 (pred == FeltCmpPredicate::LE && !cmpTrue);
933 };
934 auto geCase = [&]() {
935 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
936 (pred == FeltCmpPredicate::LT && !cmpTrue);
937 };
938
939 // new intervals based on case
940 if (eqCase()) {
941 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
942 } else if (neCase()) {
943 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
944 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
945 // an empty value range.
946 newLhsInterval = newRhsInterval = Interval::Empty(f);
947 } else if (lhsInterval.isDegenerate()) {
948 // rhs must not overlap with lhs
949 newLhsInterval = lhsInterval;
950 newRhsInterval = rhsInterval.difference(lhsInterval);
951 } else if (rhsInterval.isDegenerate()) {
952 // lhs must not overlap with rhs
953 newLhsInterval = lhsInterval.difference(rhsInterval);
954 newRhsInterval = rhsInterval;
955 } else {
956 // Leave unchanged
957 newLhsInterval = lhsInterval;
958 newRhsInterval = rhsInterval;
959 }
960 } else if (ltCase()) {
961 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
962 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
963 } else if (leCase()) {
964 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
965 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
966 } else if (gtCase()) {
967 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
968 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
969 } else if (geCase()) {
970 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
971 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
972 } else {
973 cmpOp->emitWarning("unhandled cmp predicate").report();
974 return ChangeResult::NoChange;
975 }
976
977 // Now we recurse to each operand
978 return applyInterval(originalOp, originalLattice, getOperandLattice(lhs), lhs, newLhsInterval) |
979 applyInterval(originalOp, originalLattice, getOperandLattice(rhs), rhs, newRhsInterval);
980 };
981
982 // If the result of a multiplication is non-zero, then both operands must be
983 // non-zero.
984 auto mulCase = [&](MulFeltOp mulOp) {
985 auto zeroInt = Interval::Degenerate(f, f.zero());
986 if (newInterval.intersect(zeroInt).isNotEmpty()) {
987 // The multiplication may be zero, so we can't reduce the operands to be non-zero
988 return ChangeResult::NoChange;
989 }
990
991 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
992 auto lhsLatValRes = getOperandLatticeVal(lhs), rhsLatValRes = getOperandLatticeVal(rhs);
993 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
994 return ChangeResult::NoChange;
995 }
996 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
997 rhsExpr = rhsLatValRes->getScalarValue();
998 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
999 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1000 return applyInterval(originalOp, originalLattice, getOperandLattice(lhs), lhs, newLhsInterval) |
1001 applyInterval(originalOp, originalLattice, getOperandLattice(rhs), rhs, newRhsInterval);
1002 };
1003
1004 // We have a special case for the Signal struct: if this value is created
1005 // from reading a Signal struct's reg field, we also apply the interval to
1006 // the struct itself.
1007 auto readfCase = [&](FieldReadOp readfOp) {
1008 Value comp = readfOp.getComponent();
1009 if (isSignalType(comp.getType())) {
1010 return applyInterval(originalOp, originalLattice, getOperandLattice(comp), comp, newInterval);
1011 }
1012 return ChangeResult::NoChange;
1013 };
1014
1015 // For casts, just pass the interval along to the cast's operand.
1016 auto castCase = [&](Operation *op) {
1017 Value operand = op->getOperand(0);
1018 return applyInterval(
1019 originalOp, originalLattice, getOperandLattice(operand), operand, newInterval
1020 );
1021 };
1022
1023 // - Apply the rules given the op.
1024 // NOTE: disabling clang-format for this because it makes the last case statement
1025 // look ugly.
1026 // clang-format off
1027 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
1028 .Case<CmpOp>([&](auto op) { return cmpCase(op); })
1029 .Case<MulFeltOp>([&](auto op) { return mulCase(op); })
1030 .Case<FieldReadOp>([&](auto op){ return readfCase(op); })
1031 .Case<IntToFeltOp, FeltToIndexOp>([&](auto op) { return castCase(op); })
1032 .Default([&](Operation *) { return ChangeResult::NoChange; });
1033 // clang-format on
1034
1035 // Set the new val after recursion to avoid having recursive calls unset the value.
1036 return setNewVal();
1037}
1038
1039FailureOr<std::pair<DenseSet<Value>, Interval>>
1040IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
1041 auto isZeroConst = [this](Value v) {
1042 Operation *op = v.getDefiningOp();
1043 if (!op) {
1044 return false;
1045 }
1046 if (!isConstOp(op)) {
1047 return false;
1048 }
1049 return getConst(op) == field.get().zero();
1050 };
1051 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1052 Value exprTree = nullptr;
1053 if (lhsIsZero && !rhsIsZero) {
1054 exprTree = rhs;
1055 } else if (!lhsIsZero && rhsIsZero) {
1056 exprTree = lhs;
1057 } else {
1058 return failure();
1059 }
1060
1061 // We now explore the expression tree for multiplications of subtractions/signal values.
1062 std::optional<SourceRef> signalRef = std::nullopt;
1063 DenseSet<Value> signalVals;
1064 SmallVector<DynamicAPInt> consts;
1065 SmallVector<Value> frontier {exprTree};
1066 while (!frontier.empty()) {
1067 Value v = frontier.back();
1068 frontier.pop_back();
1069 Operation *op = v.getDefiningOp();
1070
1071 FeltConstantOp c;
1072 Value signalVal;
1073 auto handleRefValue = [this, &baseOp, &signalRef, &signalVal, &signalVals]() {
1074 SourceRefLatticeValue refSet =
1075 getSourceRefLattice(baseOp, signalVal)->getOrDefault(signalVal);
1076 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1077 return failure();
1078 }
1079 SourceRef r = refSet.getSingleValue();
1080 if (signalRef.has_value() && signalRef.value() != r) {
1081 return failure();
1082 } else if (!signalRef.has_value()) {
1083 signalRef = r;
1084 }
1085 signalVals.insert(signalVal);
1086 return success();
1087 };
1088
1089 auto subPattern = m_CommutativeOp<SubFeltOp>(m_RefValue(&signalVal), m_Constant(&c));
1090 if (op && matchPattern(op, subPattern)) {
1091 if (failed(handleRefValue())) {
1092 return failure();
1093 }
1094 auto constInt = APSInt(c.getValue());
1095 consts.push_back(field.get().reduce(constInt));
1096 continue;
1097 } else if (m_RefValue(&signalVal).match(v)) {
1098 if (failed(handleRefValue())) {
1099 return failure();
1100 }
1101 consts.push_back(field.get().zero());
1102 continue;
1103 }
1104
1105 Value a, b;
1106 auto mulPattern = m_CommutativeOp<MulFeltOp>(matchers::m_Any(&a), matchers::m_Any(&b));
1107 if (op && matchPattern(op, mulPattern)) {
1108 frontier.push_back(a);
1109 frontier.push_back(b);
1110 continue;
1111 }
1112
1113 return failure();
1114 }
1115
1116 // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4),
1117 // we will create a larger range of [0, 4], since we don't support multiple intervals.
1118 std::sort(consts.begin(), consts.end());
1119 Interval iv = Interval::TypeA(field.get(), consts.front(), consts.back());
1120 return std::make_pair(std::move(signalVals), iv);
1121}
1122
1123/* StructIntervals */
1124
1125static void getReversedOps(Region *r, llvm::SmallVector<Operation *> &opList) {
1126 for (Block &b : llvm::reverse(*r)) {
1127 for (Operation &op : llvm::reverse(b)) {
1128 for (Region &nested : llvm::reverse(op.getRegions())) {
1129 getReversedOps(&nested, opList);
1130 }
1131 opList.push_back(&op);
1132 }
1133 }
1134}
1135
1137 mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx
1138) {
1139
1140 auto computeIntervalsImpl = [&solver, &ctx, this](
1141 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &fieldRanges,
1142 llvm::SetVector<ExpressionValue> &solverConstraints
1143 ) {
1144 // Since every lattice value does not contain every value, we will traverse
1145 // the function backwards (from most up-to-date to least-up-to-date lattices)
1146 // searching for the source refs. Once a source ref is found, we remove it
1147 // from the search set.
1148
1149 SourceRefSet searchSet;
1150 for (const auto &ref : SourceRef::getAllSourceRefs(structDef, fn)) {
1151 // We only want to compute intervals for field elements and not composite types,
1152 // with the exception of the Signal struct.
1153 if (!ref.isScalar() && !ref.isSignal()) {
1154 continue;
1155 }
1156 // We also don't want to show the interval for a Signal and its internal reg.
1157 if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1158 continue;
1159 }
1160 searchSet.insert(ref);
1161 }
1162
1163 // Get all ops in reverse order, including nested ops.
1164 llvm::SmallVector<Operation *> opList;
1165 getReversedOps(&fn.getBody(), opList);
1166
1167 // Also traverse the function op itself
1168 opList.push_back(fn);
1169
1170 for (Operation *op : opList) {
1171 ProgramPoint *pp = solver.getProgramPointAfter(op);
1172 const IntervalAnalysisLattice *lattice = solver.lookupState<IntervalAnalysisLattice>(pp);
1173 const auto &c = lattice->getConstraints();
1174 solverConstraints.insert(c.begin(), c.end());
1175
1176 SourceRefSet newSearchSet;
1177 for (const auto &ref : searchSet) {
1178 auto symbol = ctx.getSymbol(ref);
1179 auto intervalRes = lattice->findInterval(symbol);
1180 if (succeeded(intervalRes)) {
1181 fieldRanges[ref] = *intervalRes;
1182 } else {
1183 newSearchSet.insert(ref);
1184 }
1185 }
1186 searchSet = newSearchSet;
1187 }
1188
1189 // For all unfound refs, default to the entire range.
1190 for (const auto &ref : searchSet) {
1191 fieldRanges[ref] = Interval::Entire(ctx.getField());
1192 }
1193
1194 // Sort the outputs since we assembled things out of order.
1195 llvm::sort(fieldRanges, [](auto a, auto b) { return std::get<0>(a) < std::get<0>(b); });
1196 };
1197
1198 computeIntervalsImpl(structDef.getComputeFuncOp(), computeFieldRanges, computeSolverConstraints);
1199 computeIntervalsImpl(
1200 structDef.getConstrainFuncOp(), constrainFieldRanges, constrainSolverConstraints
1201 );
1202
1203 return success();
1204}
1205
1206void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints, bool printCompute) const {
1207 auto writeIntervals =
1208 [&os, &withConstraints](
1209 const char *fnName, const llvm::MapVector<SourceRef, Interval> &fieldRanges,
1210 const llvm::SetVector<ExpressionValue> &solverConstraints, bool printName
1211 ) {
1212 int indent = 4;
1213 if (printName) {
1214 os << '\n';
1215 os.indent(indent) << fnName << " {";
1216 indent += 4;
1217 }
1218
1219 if (fieldRanges.empty()) {
1220 os << "}\n";
1221 return;
1222 }
1223
1224 for (auto &[ref, interval] : fieldRanges) {
1225 os << '\n';
1226 os.indent(indent) << ref << " in " << interval;
1227 }
1228
1229 if (withConstraints) {
1230 os << "\n\n";
1231 os.indent(indent) << "Solver Constraints { ";
1232 if (solverConstraints.empty()) {
1233 os << "}\n";
1234 } else {
1235 for (const auto &e : solverConstraints) {
1236 os << '\n';
1237 os.indent(indent + 4);
1238 e.getExpr()->print(os);
1239 }
1240 os << '\n';
1241 os.indent(indent) << '}';
1242 }
1243 }
1244
1245 if (printName) {
1246 os << '\n';
1247 os.indent(indent - 4) << '}';
1248 }
1249 };
1250
1251 os << "StructIntervals { ";
1252 if (constrainFieldRanges.empty() && (!printCompute || computeFieldRanges.empty())) {
1253 os << "}\n";
1254 return;
1255 }
1256
1257 if (printCompute) {
1258 writeIntervals(FUNC_NAME_COMPUTE, computeFieldRanges, computeSolverConstraints, printCompute);
1259 }
1260 writeIntervals(
1261 FUNC_NAME_CONSTRAIN, constrainFieldRanges, constrainSolverConstraints, printCompute
1262 );
1263
1264 os << "\n}\n";
1265}
1266
1267} // namespace llzk
MlirStringRef name
Definition Poly.cpp:48
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
bool isBoolSort(llvm::SMTSolverRef solver) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:25
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
Definition Field.h:64
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:37
unsigned bitWidth() const
Definition Field.h:61
Maps mlir::Values to LatticeValues.
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult setValue(mlir::Value v, const LatticeValue &val)
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i)
mlir::LogicalResult visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:304
static Interval True(const Field &f)
Definition Intervals.h:219
Interval intersect(const Interval &rhs) const
Intersect.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
bool isDegenerate() const
Definition Intervals.h:306
static Interval False(const Field &f)
Definition Intervals.h:217
static Interval Degenerate(const Field &f, const llvm::DynamicAPInt &val)
Definition Intervals.h:213
Interval join(const Interval &rhs) const
Union.
Defines an index into an LLZK object.
Definition SourceRef.h:36
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
A lattice for use in dense analysis.
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
SourceRef createChild(SourceRefIndex r) const
Definition SourceRef.h:252
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:601
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
IntervalAnalysisLattice * getLattice(mlir::LatticeAnchor anchor) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
auto m_RefValue()
Definition Matchers.h:68
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_Constant()
Definition Matchers.h:88
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
Definition Matchers.h:47
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
llvm::SMTExprRef getSymbol(const SourceRef &r) const