LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13#include "llzk/Util/Debug.h"
15
16#include <mlir/Dialect/SCF/IR/SCF.h>
17
18#include <llvm/ADT/TypeSwitch.h>
19
20using namespace mlir;
21
22namespace llzk {
23
24using namespace array;
25using namespace boolean;
26using namespace cast;
27using namespace component;
28using namespace constrain;
29using namespace felt;
30using namespace function;
31
32/* ExpressionValue */
33
35 if (expr == nullptr && rhs.expr == nullptr) {
36 return i == rhs.i;
37 }
38 if (expr == nullptr || rhs.expr == nullptr) {
39 return false;
40 }
41 return i == rhs.i && *expr == *rhs.expr;
42}
43
45boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth) {
46 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
47 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
48 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.getExpr(), one, zero);
49 return expr.withExpression(boolToFeltConv);
50}
51
52ExpressionValue
53intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
54 Interval res = lhs.i.intersect(rhs.i);
55 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
56 return ExpressionValue(exprEq, res);
57}
58
60add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
62 res.i = lhs.i + rhs.i;
63 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
64 return res;
65}
66
68sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
70 res.i = lhs.i - rhs.i;
71 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
72 return res;
73}
74
76mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
78 res.i = lhs.i * rhs.i;
79 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
80 return res;
81}
82
84div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs,
85 const ExpressionValue &rhs) {
87 auto divRes = lhs.i / rhs.i;
88 if (failed(divRes)) {
89 op->emitWarning(
90 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
91 " Range of division result will be treated as unbounded."
92 )
93 .report();
94 res.i = Interval::Entire(lhs.getField());
95 } else {
96 res.i = *divRes;
97 }
98 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
99 return res;
100}
101
103mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
104 ExpressionValue res;
105 res.i = lhs.i % rhs.i;
106 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
107 return res;
108}
109
111bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
112 ExpressionValue res;
113 res.i = lhs.i & rhs.i;
114 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
115 return res;
116}
117
119shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
120 ExpressionValue res;
121 res.i = lhs.i << rhs.i;
122 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
123 return res;
124}
125
127shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
128 ExpressionValue res;
129 res.i = lhs.i >> rhs.i;
130 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
131 return res;
132}
133
135cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs) {
136 ExpressionValue res;
137 const Field &f = lhs.getField();
138 // Default result is any boolean output for when we are unsure about the comparison result.
139 res.i = Interval::Boolean(f);
140 switch (op.getPredicate()) {
141 case FeltCmpPredicate::EQ:
142 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
143 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
144 res.i = lhs.i == rhs.i ? Interval::True(f) : Interval::False(f);
145 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
146 res.i = Interval::False(f);
147 }
148 break;
149 case FeltCmpPredicate::NE:
150 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
151 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
152 res.i = lhs.i != rhs.i ? Interval::True(f) : Interval::False(f);
153 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
154 res.i = Interval::True(f);
155 }
156 break;
157 case FeltCmpPredicate::LT:
158 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
159 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
160 res.i = Interval::True(f);
161 }
162 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
163 res.i = Interval::False(f);
164 }
165 break;
166 case FeltCmpPredicate::LE:
167 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
168 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
169 res.i = Interval::True(f);
170 }
171 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
172 res.i = Interval::False(f);
173 }
174 break;
175 case FeltCmpPredicate::GT:
176 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
177 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
178 res.i = Interval::True(f);
179 }
180 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
181 res.i = Interval::False(f);
182 }
183 break;
184 case FeltCmpPredicate::GE:
185 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
186 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
187 res.i = Interval::True(f);
188 }
189 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
190 res.i = Interval::False(f);
191 }
192 break;
193 }
194 return res;
195}
196
198boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
199 ExpressionValue res;
200 res.i = boolAnd(lhs.i, rhs.i);
201 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
202 return res;
203}
204
206boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
207 ExpressionValue res;
208 res.i = boolOr(lhs.i, rhs.i);
209 res.expr = solver->mkOr(lhs.expr, rhs.expr);
210 return res;
211}
212
214boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
215 ExpressionValue res;
216 res.i = boolXor(lhs.i, rhs.i);
217 // There's no Xor, so we do (L || R) && !(L && R)
218 res.expr = solver->mkAnd(
219 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
220 );
221 return res;
222}
223
225 llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs
226) {
227 ExpressionValue res;
228 res.i = Interval::Entire(lhs.getField());
229 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
230 .Case<OrFeltOp>([&](auto) { return solver->mkBVOr(lhs.expr, rhs.expr); })
231 .Case<XorFeltOp>([&](auto) {
232 return solver->mkBVXor(lhs.expr, rhs.expr);
233 }).Default([&](auto *unsupported) {
234 llvm::report_fatal_error(
235 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
236 );
237 return nullptr;
238 });
239
240 return res;
241}
242
243ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val) {
244 ExpressionValue res;
245 res.i = -val.i;
246 res.expr = solver->mkBVNeg(val.expr);
247 return res;
248}
249
250ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val) {
251 ExpressionValue res;
252 res.i = ~val.i;
253 res.expr = solver->mkBVNot(val.expr);
254 return res;
255}
256
257ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val) {
258 ExpressionValue res;
259 res.i = boolNot(val.i);
260 res.expr = solver->mkNot(val.expr);
261 return res;
262}
263
265fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val) {
266 const Field &field = val.getField();
267 ExpressionValue res;
268 res.i = Interval::Entire(field);
269 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
270 .Case<InvFeltOp>([&](auto) {
271 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
272 // To create this expression, we create a new symbol for Y and add the
273 // XY % prime = 1 constraint to the solver.
274 std::string symName = buildStringViaInsertionOp(*op);
275 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
276 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.bitWidth());
277 llvm::SMTExprRef prime = solver->mkBitvector(toAPSInt(field.prime()), field.bitWidth());
278 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
279 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
280 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
281 solver->addConstraint(constraint);
282 return invSym;
283 }).Default([](Operation *unsupported) {
284 llvm::report_fatal_error(
285 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
286 );
287 return nullptr;
288 });
289
290 return res;
291}
292
293void ExpressionValue::print(mlir::raw_ostream &os) const {
294 if (expr) {
295 expr->print(os);
296 } else {
297 os << "<null expression>";
298 }
299
300 os << " ( interval: " << i << " )";
301}
302
303/* IntervalAnalysisLattice */
304
305ChangeResult IntervalAnalysisLattice::join(const AbstractSparseLattice &other) {
306 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
307 if (!rhs) {
308 llvm::report_fatal_error("invalid join lattice type");
309 }
310 ChangeResult res = val.update(rhs->getValue());
311 for (auto &v : rhs->constraints) {
312 if (!constraints.contains(v)) {
313 constraints.insert(v);
314 res |= ChangeResult::Change;
315 }
316 }
317 return res;
318}
319
320ChangeResult IntervalAnalysisLattice::meet(const AbstractSparseLattice &other) {
321 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
322 if (!rhs) {
323 llvm::report_fatal_error("invalid join lattice type");
324 }
325 // Intersect the intervals
326 ExpressionValue lhsExpr = val.getScalarValue();
327 ExpressionValue rhsExpr = rhs->getValue().getScalarValue();
328 Interval newInterval = lhsExpr.getInterval().intersect(rhsExpr.getInterval());
329 ChangeResult res = setValue(lhsExpr.withInterval(newInterval));
330 for (auto &v : rhs->constraints) {
331 if (!constraints.contains(v)) {
332 constraints.insert(v);
333 res |= ChangeResult::Change;
334 }
335 }
336 return res;
337}
338
339void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
340 os << "IntervalAnalysisLattice { " << val << " }";
341}
342
344 if (val == newVal) {
345 return ChangeResult::NoChange;
346 }
347 val = newVal;
348 return ChangeResult::Change;
349}
350
352 LatticeValue newVal(e);
353 return setValue(newVal);
354}
355
357 if (!constraints.contains(e)) {
358 constraints.insert(e);
359 return ChangeResult::Change;
360 }
361 return ChangeResult::NoChange;
362}
363
364/* IntervalDataFlowAnalysis */
365
366const SourceRefLattice *
367IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) {
368 ProgramPoint *pp = _dataflowSolver.getProgramPointAfter(baseOp);
369 auto defaultSourceRefLattice = _dataflowSolver.lookupState<SourceRefLattice>(pp);
370 ensure(defaultSourceRefLattice, "failed to get lattice");
371 if (Operation *defOp = val.getDefiningOp()) {
372 ProgramPoint *defPoint = _dataflowSolver.getProgramPointAfter(defOp);
373 auto sourceRefLattice = _dataflowSolver.lookupState<SourceRefLattice>(defPoint);
374 ensure(sourceRefLattice, "failed to get SourceRefLattice for value");
375 return sourceRefLattice;
376 }
377 return defaultSourceRefLattice;
378}
379
381 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
382) {
383 // We only perform the visitation on operations within functions
384 FuncDefOp fn = op->getParentOfType<FuncDefOp>();
385 if (!fn) {
386 return success();
387 }
388
389 // If there are no operands or results, skip.
390 if (operands.empty() && results.empty()) {
391 return success();
392 }
393
394 // Get the values or defaults from the operand lattices
395 llvm::SmallVector<LatticeValue> operandVals;
396 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
397 for (unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
398 Value val = op->getOperand(opNum);
399 SourceRefLatticeValue refSet = getSourceRefLattice(op, val)->getOrDefault(val);
400 if (refSet.isSingleValue()) {
401 operandRefs.push_back(refSet.getSingleValue());
402 } else {
403 operandRefs.push_back(std::nullopt);
404 }
405 // First, lookup the operand value after it is initialized
406 auto priorState = operands[opNum]->getValue();
407 if (priorState.getScalarValue().getExpr() != nullptr) {
408 operandVals.push_back(priorState);
409 continue;
410 }
411
412 // Else, look up the stored value by `SourceRef`.
413 // We only care about scalar type values, so we ignore composite types, which
414 // are currently limited to non-Signal structs and arrays.
415 Type valTy = val.getType();
416 if (llvm::isa<ArrayType, StructType>(valTy) && !isSignalType(valTy)) {
417 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
418 operandVals.emplace_back(anyVal);
419 continue;
420 }
421
422 ensure(refSet.isScalar(), "should have ruled out array values already");
423
424 if (refSet.getScalarValue().empty()) {
425 // If we can't compute the reference, then there must be some unsupported
426 // op the reference analysis cannot handle. We emit a warning and return
427 // early, since there's no meaningful computation we can do for this op.
428 op->emitWarning()
429 .append(
430 "state of ", val, " is empty; defining operation is unsupported by SourceRef analysis"
431 )
432 .report();
433 // We still return success so we can return overapproximated and partial
434 // results to the user.
435 return success();
436 } else if (!refSet.isSingleValue()) {
437 std::string warning;
438 debug::Appender(warning) << "operand " << val << " is not a single value " << refSet
439 << ", overapproximating";
440 op->emitWarning(warning).report();
441 // Here, we will override the prior lattice value with a new symbol, representing
442 // "any" value, then use that value for the operands.
443 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
444 operandVals.emplace_back(anyVal);
445 } else {
446 const SourceRef &ref = refSet.getSingleValue();
447 // See if we've written the value before. If so, use that.
448 if (auto it = fieldWriteResults.find(ref); it != fieldWriteResults.end()) {
449 operandVals.emplace_back(it->second);
450 } else {
451 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
452 operandVals.emplace_back(exprVal);
453 }
454 }
455
456 // Since we initialized a value that was not found in the before lattice,
457 // update that value in the lattice so we can find it later, but we don't
458 // need to propagate the changes, since we already have what we need.
459 Lattice *operandLattice = getLatticeElement(val);
460 (void)operandLattice->setValue(operandVals[opNum]);
461 }
462
463 // Now, the way we update is dependent on the type of the operation.
464 if (isConstOp(op)) {
465 llvm::DynamicAPInt constVal = getConst(op);
466 llvm::SMTExprRef expr = createConstBitvectorExpr(constVal);
467 ExpressionValue latticeVal(field.get(), expr, constVal);
468 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
469 } else if (isArithmeticOp(op)) {
470 ExpressionValue result;
471 if (operands.size() == 2) {
472 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
473 } else {
474 result = performUnaryArithmetic(op, operandVals[0]);
475 }
476 // Also intersect with prior interval, if it's initialized
477 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
478 if (prior.getExpr()) {
479 result = result.withInterval(result.getInterval().intersect(prior.getInterval()));
480 }
481 propagateIfChanged(results[0], results[0]->setValue(result));
482 } else if (EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
483 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
484 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
485 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
486
487 // Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
488 // These patterns enforce that s is one of c0, ..., cN.
489 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
490 if (succeeded(res)) {
491 for (Value signalVal : res->first) {
492 applyInterval(emitEq, signalVal, res->second);
493 }
494 }
495
496 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
497 // Update the LHS and RHS to the same value, but restricted intervals
498 // based on the constraints.
499 const Interval &constrainInterval = constraint.getInterval();
500 applyInterval(emitEq, lhsVal, constrainInterval);
501 applyInterval(emitEq, rhsVal, constrainInterval);
502 } else if (auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
503 // assert enforces that the operand is true. So we apply an interval of [1, 1]
504 // to the operand.
505 Value cond = assertOp.getCondition();
506 applyInterval(assertOp, cond, Interval::True(field.get()));
507 // Also add the solver constraint that the expression must be true.
508 auto assertExpr = operandVals[0].getScalarValue();
509 // No need to propagate the constraint
510 (void)getLatticeElement(cond)->addSolverConstraint(assertExpr);
511 } else if (auto readf = llvm::dyn_cast<FieldReadOp>(op)) {
512 Value cmp = readf.getComponent();
513 if (isSignalType(cmp.getType())) {
514 // The reg value read from the signal type is equal to the value of the Signal
515 // struct overall.
516 propagateIfChanged(results[0], results[0]->setValue(operandVals[0]));
517 }
518 } else if (auto writef = llvm::dyn_cast<FieldWriteOp>(op)) {
519 // Update values stored in a field
520 ExpressionValue writeVal = operandVals[1].getScalarValue();
521 auto cmp = writef.getComponent();
522 // We also need to update the interval on the assigned symbol
523 SourceRefLatticeValue refSet = getSourceRefLattice(op, cmp)->getOrDefault(cmp);
524 if (refSet.isSingleValue()) {
525 auto fieldDefRes = writef.getFieldDefOp(tables);
526 if (succeeded(fieldDefRes)) {
527 SourceRefIndex idx(fieldDefRes.value());
528 SourceRef fieldRef = refSet.getSingleValue().createChild(idx);
529 llvm::SMTExprRef expr = getOrCreateSymbol(fieldRef);
530 ExpressionValue written(expr, writeVal.getInterval());
531
532 if (auto it = fieldWriteResults.find(fieldRef); it != fieldWriteResults.end()) {
533 const ExpressionValue &old = it->second;
534 Interval combinedWrite = old.getInterval().join(written.getInterval());
535 fieldWriteResults[fieldRef] = old.withInterval(combinedWrite);
536 } else {
537 fieldWriteResults[fieldRef] = written;
538 }
539
540 // Propagate to all field readers we've collected so far.
541 for (Lattice *readerLattice : fieldReadResults[fieldRef]) {
542 ExpressionValue prior = readerLattice->getValue().getScalarValue();
545 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
546 }
547 }
548 }
549 } else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
550 // Casts don't modify the intervals, but they do modify the SMT types.
551 ExpressionValue expr = operandVals[0].getScalarValue();
552 // We treat all ints and indexes as felts with the exception of comparison
553 // results, which are bools. So if `expr` is a bool, this cast needs to
554 // upcast to a felt.
555 if (expr.isBoolSort(smtSolver)) {
556 expr = boolToFelt(smtSolver, expr, field.get().bitWidth());
557 }
558 propagateIfChanged(results[0], results[0]->setValue(expr));
559 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
560 // Fetch the lattice for after the parent operation so we can propagate
561 // the yielded value to subsequent operations.
562 Operation *parent = op->getParentOp();
563 ensure(parent, "yield operation must have parent operation");
564 // Bind the operand values to the result values of the parent
565 for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
566 Value parentRes = parent->getResult(idx);
567 Lattice *resLattice = getLatticeElement(parentRes);
568 // Merge with the existing value, if present (e.g., another branch)
569 // has possible value that must be merged.
570 ExpressionValue exprVal = resLattice->getValue().getScalarValue();
571 ExpressionValue newResVal = operandVals[idx].getScalarValue();
572 if (exprVal.getExpr() != nullptr) {
573 newResVal = exprVal.withInterval(exprVal.getInterval().join(newResVal.getInterval()));
574 } else {
575 newResVal = ExpressionValue(createFeltSymbol(parentRes), newResVal.getInterval());
576 }
577 propagateIfChanged(resLattice, resLattice->setValue(newResVal));
578 }
579 } else if (
580 // We do not need to explicitly handle read ops since they are resolved at the operand value
581 // step where `SourceRef`s are queries (with the exception of the Signal struct, see above).
582 !isReadOp(op)
583 // We do not currently handle return ops as the analysis is currently limited to constrain
584 // functions, which return no value.
585 && !isReturnOp(op)
586 // The analysis ignores definition ops.
587 && !isDefinitionOp(op)
588 // We do not need to analyze the creation of structs.
589 && !llvm::isa<CreateStructOp>(op)
590 ) {
591 op->emitWarning("unhandled operation, analysis may be incomplete").report();
592 }
593
594 return success();
595}
596
598 auto it = refSymbols.find(r);
599 if (it != refSymbols.end()) {
600 return it->second;
601 }
602 llvm::SMTExprRef sym = createFeltSymbol(r);
603 refSymbols[r] = sym;
604 return sym;
605}
606
607llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const SourceRef &r) const {
608 return createFeltSymbol(buildStringViaPrint(r).c_str());
609}
610
611llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v) const {
612 return createFeltSymbol(buildStringViaPrint(v).c_str());
613}
614
615llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) const {
616 return field.get().createSymbol(smtSolver, name);
617}
618
619llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
620 ensure(isConstOp(op), "op is not a const op");
621
622 llvm::DynamicAPInt fieldConst =
623 TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
624 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
625 llvm::APSInt constOpVal(feltConst.getValue());
626 return field.get().reduce(constOpVal);
627 })
628 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
629 return DynamicAPInt(indexConst.value());
630 })
631 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
632 return DynamicAPInt(intConst.value());
633 }).Default([](Operation *illegalOp) {
634 std::string err;
635 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
636 llvm::report_fatal_error(Twine(err));
637 return llvm::DynamicAPInt();
638 });
639 return fieldConst;
640}
641
642ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
643 Operation *op, const LatticeValue &a, const LatticeValue &b
644) {
645 ensure(isArithmeticOp(op), "is not arithmetic op");
646
647 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
648 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
649 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
650
651 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
652 .Case<AddFeltOp>([&](auto _) { return add(smtSolver, lhs, rhs); })
653 .Case<SubFeltOp>([&](auto _) { return sub(smtSolver, lhs, rhs); })
654 .Case<MulFeltOp>([&](auto _) { return mul(smtSolver, lhs, rhs); })
655 .Case<DivFeltOp>([&](auto divOp) { return div(smtSolver, divOp, lhs, rhs); })
656 .Case<ModFeltOp>([&](auto _) { return mod(smtSolver, lhs, rhs); })
657 .Case<AndFeltOp>([&](auto _) { return bitAnd(smtSolver, lhs, rhs); })
658 .Case<ShlFeltOp>([&](auto _) { return shiftLeft(smtSolver, lhs, rhs); })
659 .Case<ShrFeltOp>([&](auto _) { return shiftRight(smtSolver, lhs, rhs); })
660 .Case<CmpOp>([&](auto cmpOp) { return cmp(smtSolver, cmpOp, lhs, rhs); })
661 .Case<AndBoolOp>([&](auto _) { return boolAnd(smtSolver, lhs, rhs); })
662 .Case<OrBoolOp>([&](auto _) { return boolOr(smtSolver, lhs, rhs); })
663 .Case<XorBoolOp>([&](auto _) {
664 return boolXor(smtSolver, lhs, rhs);
665 }).Default([&](auto *unsupported) {
666 unsupported
667 ->emitWarning(
668 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
669 )
670 .report();
671 return fallbackBinaryOp(smtSolver, unsupported, lhs, rhs);
672 });
673
674 ensure(res.getExpr(), "arithmetic produced null smt expr");
675 return res;
676}
677
679IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
680 ensure(isArithmeticOp(op), "is not arithmetic op");
681
682 auto val = a.getScalarValue();
683 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
684
685 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
686 .Case<NegFeltOp>([&](auto _) { return neg(smtSolver, val); })
687 .Case<NotFeltOp>([&](auto _) { return notOp(smtSolver, val); })
688 .Case<NotBoolOp>([&](auto _) { return boolNot(smtSolver, val); })
689 // The inverse op is currently overapproximated
690 .Case<InvFeltOp>([&](auto inv) {
691 return fallbackUnaryOp(smtSolver, inv, val);
692 }).Default([&](auto *unsupported) {
693 unsupported
694 ->emitWarning(
695 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
696 )
697 .report();
698 return fallbackUnaryOp(smtSolver, unsupported, val);
699 });
700
701 ensure(res.getExpr(), "arithmetic produced null smt expr");
702 return res;
703}
704
705void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Interval newInterval) {
706 Lattice *valLattice = getLatticeElement(val);
707 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
708 // Intersect with the current value to accumulate restrictions across constraints.
709 Interval intersection = oldLatticeVal.getInterval().intersect(newInterval);
710 ExpressionValue newLatticeVal = oldLatticeVal.withInterval(intersection);
711 ChangeResult changed = valLattice->setValue(newLatticeVal);
712
713 if (auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
714 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
715
716 // Apply the interval from the constrain function inputs to the compute function inputs
717 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
718 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
719 auto structOp = fnOp->getParentOfType<StructDefOp>();
720 FuncDefOp computeFn = structOp.getComputeFuncOp();
721 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
722 Lattice *computeEntryLattice = getLatticeElement(computeArg);
723
724 SourceRef ref(computeArg);
725 ExpressionValue newArgVal(getOrCreateSymbol(ref), newInterval);
726 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
727 }
728 }
729
730 // Now we descend into val's operands, if it has any.
731 Operation *definingOp = val.getDefiningOp();
732 if (!definingOp) {
733 propagateIfChanged(valLattice, changed);
734 return;
735 }
736
737 const Field &f = field.get();
738
739 // This is a rules-based operation. If we have a rule for a given operation,
740 // then we can make some kind of update, otherwise we leave the intervals
741 // as is.
742 // - First we'll define all the rules so the type switch can be less messy
743
744 // cmp.<pred> restricts each side of the comparison if the result is known.
745 auto cmpCase = [&](CmpOp cmpOp) {
746 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
747 // either "true" (1) or "false" (0).
748 // -- In the case of a contradictory circuit, however, the cmp result is allowed
749 // to be empty.
750 ensure(
751 newInterval.isBoolean() || newInterval.isEmpty(),
752 "new interval for CmpOp is not boolean or empty"
753 );
754 if (!newInterval.isDegenerate()) {
755 // The comparison result is unknown, so we can't update the operand ranges
756 return;
757 }
758
759 bool cmpTrue = newInterval.rhs() == f.one();
760
761 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
762 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
763 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
764 rhsExpr = rhsLat->getValue().getScalarValue();
765
766 Interval newLhsInterval, newRhsInterval;
767 const Interval &lhsInterval = lhsExpr.getInterval();
768 const Interval &rhsInterval = rhsExpr.getInterval();
769
770 FeltCmpPredicate pred = cmpOp.getPredicate();
771 // predicate cases
772 auto eqCase = [&]() {
773 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
774 (pred == FeltCmpPredicate::NE && !cmpTrue);
775 };
776 auto neCase = [&]() {
777 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
778 (pred == FeltCmpPredicate::EQ && !cmpTrue);
779 };
780 auto ltCase = [&]() {
781 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
782 (pred == FeltCmpPredicate::GE && !cmpTrue);
783 };
784 auto leCase = [&]() {
785 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
786 (pred == FeltCmpPredicate::GT && !cmpTrue);
787 };
788 auto gtCase = [&]() {
789 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
790 (pred == FeltCmpPredicate::LE && !cmpTrue);
791 };
792 auto geCase = [&]() {
793 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
794 (pred == FeltCmpPredicate::LT && !cmpTrue);
795 };
796
797 // new intervals based on case
798 if (eqCase()) {
799 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
800 } else if (neCase()) {
801 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
802 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
803 // an empty value range.
804 newLhsInterval = newRhsInterval = Interval::Empty(f);
805 } else if (lhsInterval.isDegenerate()) {
806 // rhs must not overlap with lhs
807 newLhsInterval = lhsInterval;
808 newRhsInterval = rhsInterval.difference(lhsInterval);
809 } else if (rhsInterval.isDegenerate()) {
810 // lhs must not overlap with rhs
811 newLhsInterval = lhsInterval.difference(rhsInterval);
812 newRhsInterval = rhsInterval;
813 } else {
814 // Leave unchanged
815 newLhsInterval = lhsInterval;
816 newRhsInterval = rhsInterval;
817 }
818 } else if (ltCase()) {
819 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
820 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
821 } else if (leCase()) {
822 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
823 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
824 } else if (gtCase()) {
825 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
826 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
827 } else if (geCase()) {
828 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
829 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
830 } else {
831 cmpOp->emitWarning("unhandled cmp predicate").report();
832 return;
833 }
834
835 // Now we recurse to each operand
836 applyInterval(cmpOp, lhs, newLhsInterval);
837 applyInterval(cmpOp, rhs, newRhsInterval);
838 };
839
840 // Multiplication cases:
841 // - If the result of a multiplication is non-zero, then both operands must be
842 // non-zero.
843 // - If one operand is a constant, we can propagate the new interval when multiplied
844 // by the multiplicative inverse of the constant.
845 auto mulCase = [&](MulFeltOp mulOp) {
846 // We check for the constant case first.
847 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
848 auto latVal = getLatticeElement(multiplicand)->getValue().getScalarValue();
849 APInt constVal = constOperand.getValue();
850 if (constVal.isZero()) {
851 // There's no inverse for zero, so we do nothing.
852 return;
853 }
854 Interval updatedInterval = newInterval * Interval::Degenerate(f, f.inv(constVal));
855 applyInterval(mulOp, multiplicand, updatedInterval);
856 };
857
858 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
859
860 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
861 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
862 // If both are consts, we don't need to do anything
863 if (lhsConstOp && rhsConstOp) {
864 return;
865 } else if (lhsConstOp) {
866 constCase(lhsConstOp, rhs);
867 return;
868 } else if (rhsConstOp) {
869 constCase(rhsConstOp, lhs);
870 return;
871 }
872
873 // Otherwise, try to propagate non-zero information.
874 auto zeroInt = Interval::Degenerate(f, f.zero());
875 if (newInterval.intersect(zeroInt).isNotEmpty()) {
876 // The multiplication may be zero, so we can't reduce the operands to be non-zero
877 return;
878 }
879
880 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
881 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
882 rhsExpr = rhsLat->getValue().getScalarValue();
883 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
884 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
885 applyInterval(mulOp, lhs, newLhsInterval);
886 applyInterval(mulOp, rhs, newRhsInterval);
887 };
888
889 auto addCase = [&](AddFeltOp addOp) {
890 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
891 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
892 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
893 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
894
895 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
896
897 Interval derivedLhsInt = newInterval - currRhsInt;
898 Interval derivedRhsInt = newInterval - currLhsInt;
899
900 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
901 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
902
903 applyInterval(addOp, lhs, finalLhsInt);
904 applyInterval(addOp, rhs, finalRhsInt);
905 };
906
907 auto subCase = [&](SubFeltOp subOp) {
908 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
909 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
910 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
911 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
912
913 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
914
915 Interval derivedLhsInt = newInterval + currRhsInt;
916 Interval derivedRhsInt = currLhsInt - newInterval;
917
918 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
919 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
920
921 applyInterval(subOp, lhs, finalLhsInt);
922 applyInterval(subOp, rhs, finalRhsInt);
923 };
924
925 auto readfCase = [&](FieldReadOp readfOp) {
926 const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val);
927 SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val);
928
929 if (sourceRefVal.isSingleValue()) {
930 const SourceRef &ref = sourceRefVal.getSingleValue();
931 fieldReadResults[ref].insert(valLattice);
932
933 // Also propagate to all other field read results for this field
934 for (Lattice *l : fieldReadResults[ref]) {
935 if (l != valLattice) {
936 propagateIfChanged(l, l->setValue(newLatticeVal));
937 }
938 }
939 }
940
941 // We have a special case for the Signal struct: if this value is created
942 // from reading a Signal struct's reg field, we also apply the interval to
943 // the struct itself.
944 Value comp = readfOp.getComponent();
945 if (isSignalType(comp.getType())) {
946 applyInterval(readfOp, comp, newInterval);
947 }
948 };
949
950 auto readArrCase = [&](ReadArrayOp _) {
951 const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val);
952 SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val);
953
954 if (sourceRefVal.isSingleValue()) {
955 const SourceRef &ref = sourceRefVal.getSingleValue();
956 fieldReadResults[ref].insert(valLattice);
957
958 // Also propagate to all other field read results for this field
959 for (Lattice *l : fieldReadResults[ref]) {
960 if (l != valLattice) {
961 propagateIfChanged(l, l->setValue(newLatticeVal));
962 }
963 }
964 }
965 };
966
967 // For casts, just pass the interval along to the cast's operand.
968 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
969
970 // - Apply the rules given the op.
971 // NOTE: disabling clang-format for this because it makes the last case statement
972 // look ugly.
973 // clang-format off
974 TypeSwitch<Operation *>(definingOp)
975 .Case<CmpOp>([&](auto op) { cmpCase(op); })
976 .Case<AddFeltOp>([&](auto op) { return addCase(op); })
977 .Case<SubFeltOp>([&](auto op) { return subCase(op); })
978 .Case<MulFeltOp>([&](auto op) { mulCase(op); })
979 .Case<FieldReadOp>([&](auto op){ readfCase(op); })
980 .Case<ReadArrayOp>([&](auto op){ readArrCase(op); })
981 .Case<IntToFeltOp, FeltToIndexOp>([&](auto op) { castCase(op); })
982 .Default([&](Operation *) { });
983 // clang-format on
984
985 // Propagate after recursion to avoid having recursive calls unset the value.
986 propagateIfChanged(valLattice, changed);
987}
988
989FailureOr<std::pair<DenseSet<Value>, Interval>>
990IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
991 auto isZeroConst = [this](Value v) {
992 Operation *op = v.getDefiningOp();
993 if (!op) {
994 return false;
995 }
996 if (!isConstOp(op)) {
997 return false;
998 }
999 return getConst(op) == field.get().zero();
1000 };
1001 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1002 Value exprTree = nullptr;
1003 if (lhsIsZero && !rhsIsZero) {
1004 exprTree = rhs;
1005 } else if (!lhsIsZero && rhsIsZero) {
1006 exprTree = lhs;
1007 } else {
1008 return failure();
1009 }
1010
1011 // We now explore the expression tree for multiplications of subtractions/signal values.
1012 std::optional<SourceRef> signalRef = std::nullopt;
1013 DenseSet<Value> signalVals;
1014 SmallVector<DynamicAPInt> consts;
1015 SmallVector<Value> frontier {exprTree};
1016 while (!frontier.empty()) {
1017 Value v = frontier.back();
1018 frontier.pop_back();
1019 Operation *op = v.getDefiningOp();
1020
1021 FeltConstantOp c;
1022 Value signalVal;
1023 auto handleRefValue = [this, &baseOp, &signalRef, &signalVal, &signalVals]() {
1024 SourceRefLatticeValue refSet =
1025 getSourceRefLattice(baseOp, signalVal)->getOrDefault(signalVal);
1026 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1027 return failure();
1028 }
1029 SourceRef r = refSet.getSingleValue();
1030 if (signalRef.has_value() && signalRef.value() != r) {
1031 return failure();
1032 } else if (!signalRef.has_value()) {
1033 signalRef = r;
1034 }
1035 signalVals.insert(signalVal);
1036 return success();
1037 };
1038
1039 auto subPattern = m_CommutativeOp<SubFeltOp>(m_RefValue(&signalVal), m_Constant(&c));
1040 if (op && matchPattern(op, subPattern)) {
1041 if (failed(handleRefValue())) {
1042 return failure();
1043 }
1044 auto constInt = APSInt(c.getValue());
1045 consts.push_back(field.get().reduce(constInt));
1046 continue;
1047 } else if (m_RefValue(&signalVal).match(v)) {
1048 if (failed(handleRefValue())) {
1049 return failure();
1050 }
1051 consts.push_back(field.get().zero());
1052 continue;
1053 }
1054
1055 Value a, b;
1056 auto mulPattern = m_CommutativeOp<MulFeltOp>(matchers::m_Any(&a), matchers::m_Any(&b));
1057 if (op && matchPattern(op, mulPattern)) {
1058 frontier.push_back(a);
1059 frontier.push_back(b);
1060 continue;
1061 }
1062
1063 return failure();
1064 }
1065
1066 // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4),
1067 // we will create a larger range of [0, 4], since we don't support multiple intervals.
1068 std::sort(consts.begin(), consts.end());
1069 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1070 return std::make_pair(std::move(signalVals), iv);
1071}
1072
1073/* StructIntervals */
1074
1076 mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx
1077) {
1078
1079 auto validSourceRefType = [](const SourceRef &ref) {
1080 // We only want to compute intervals for field elements and not composite types,
1081 // with the exception of the Signal struct.
1082 if (!ref.isScalar() && !ref.isSignal()) {
1083 return false;
1084 }
1085 // We also don't want to show the interval for a Signal and its internal reg.
1086 if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1087 return false;
1088 }
1089 return true;
1090 };
1091
1092 auto computeIntervalsImpl = [&solver, &ctx, &validSourceRefType, this](
1093 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &fieldRanges,
1094 llvm::SetVector<ExpressionValue> &solverConstraints
1095 ) {
1096 // Since every lattice value does not contain every value, we will traverse
1097 // the function backwards (from most up-to-date to least-up-to-date lattices)
1098 // searching for the source refs. Once a source ref is found, we remove it
1099 // from the search set.
1100
1101 SourceRefSet searchSet;
1102 for (const auto &ref : SourceRef::getAllSourceRefs(structDef, fn)) {
1103 // We only want to compute intervals for field elements and not composite types,
1104 // with the exception of the Signal struct.
1105 if (validSourceRefType(ref)) {
1106 searchSet.insert(ref);
1107 }
1108 }
1109
1110 // Iterate over arguments
1111 for (BlockArgument arg : fn.getArguments()) {
1112 SourceRef ref {arg};
1113 if (searchSet.erase(ref)) {
1114 const IntervalAnalysisLattice *lattice = solver.lookupState<IntervalAnalysisLattice>(arg);
1115 // If we never referenced this argument, use a default value
1116 ExpressionValue expr = lattice->getValue().getScalarValue();
1117 if (!expr.getExpr()) {
1118 expr = expr.withInterval(Interval::Entire(ctx.getField()));
1119 }
1120 fieldRanges[ref] = expr.getInterval();
1121 assert(fieldRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1122 }
1123 }
1124
1125 // Iterate over fields that were touched by the analysis
1126 for (const auto &[ref, lattices] : ctx.intervalDFA->getFieldReadResults()) {
1127 // All lattices should have the same value, so we can get the front.
1128 if (!lattices.empty() && searchSet.erase(ref)) {
1129 const IntervalAnalysisLattice *lattice = *lattices.begin();
1130 fieldRanges[ref] = lattice->getValue().getScalarValue().getInterval();
1131 assert(fieldRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1132 }
1133 }
1134
1135 for (const auto &[ref, val] : ctx.intervalDFA->getFieldWriteResults()) {
1136 if (searchSet.erase(ref)) {
1137 fieldRanges[ref] = val.getInterval();
1138 assert(fieldRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1139 }
1140 }
1141
1142 // For all unfound refs, default to the entire range.
1143 for (const auto &ref : searchSet) {
1144 fieldRanges[ref] = Interval::Entire(ctx.getField());
1145 }
1146
1147 // Sort the outputs since we assembled things out of order.
1148 llvm::sort(fieldRanges, [](auto a, auto b) { return std::get<0>(a) < std::get<0>(b); });
1149 };
1150
1151 computeIntervalsImpl(structDef.getComputeFuncOp(), computeFieldRanges, computeSolverConstraints);
1152 computeIntervalsImpl(
1153 structDef.getConstrainFuncOp(), constrainFieldRanges, constrainSolverConstraints
1154 );
1155
1156 return success();
1157}
1158
1159void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints, bool printCompute) const {
1160 auto writeIntervals =
1161 [&os, &withConstraints](
1162 const char *fnName, const llvm::MapVector<SourceRef, Interval> &fieldRanges,
1163 const llvm::SetVector<ExpressionValue> &solverConstraints, bool printName
1164 ) {
1165 int indent = 4;
1166 if (printName) {
1167 os << '\n';
1168 os.indent(indent) << fnName << " {";
1169 indent += 4;
1170 }
1171
1172 if (fieldRanges.empty()) {
1173 os << "}\n";
1174 return;
1175 }
1176
1177 for (auto &[ref, interval] : fieldRanges) {
1178 os << '\n';
1179 os.indent(indent) << ref << " in " << interval;
1180 }
1181
1182 if (withConstraints) {
1183 os << "\n\n";
1184 os.indent(indent) << "Solver Constraints { ";
1185 if (solverConstraints.empty()) {
1186 os << "}\n";
1187 } else {
1188 for (const auto &e : solverConstraints) {
1189 os << '\n';
1190 os.indent(indent + 4);
1191 e.getExpr()->print(os);
1192 }
1193 os << '\n';
1194 os.indent(indent) << '}';
1195 }
1196 }
1197
1198 if (printName) {
1199 os << '\n';
1200 os.indent(indent - 4) << '}';
1201 }
1202 };
1203
1204 os << "StructIntervals { ";
1205 if (constrainFieldRanges.empty() && (!printCompute || computeFieldRanges.empty())) {
1206 os << "}\n";
1207 return;
1208 }
1209
1210 if (printCompute) {
1211 writeIntervals(FUNC_NAME_COMPUTE, computeFieldRanges, computeSolverConstraints, printCompute);
1212 }
1213 writeIntervals(
1214 FUNC_NAME_CONSTRAIN, constrainFieldRanges, constrainSolverConstraints, printCompute
1215 );
1216
1217 os << "\n}\n";
1218}
1219
1220} // namespace llzk
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
bool isBoolSort(llvm::SMTSolverRef solver) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:27
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
Definition Field.h:71
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:39
unsigned bitWidth() const
Definition Field.h:68
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getFieldReadResults() const
const llvm::DenseMap< SourceRef, ExpressionValue > & getFieldWriteResults() const
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:304
static Interval True(const Field &f)
Definition Intervals.h:219
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
static Interval Entire(const Field &f)
Definition Intervals.h:223
bool isDegenerate() const
Definition Intervals.h:306
static Interval False(const Field &f)
Definition Intervals.h:217
Interval join(const Interval &rhs) const
Union.
Defines an index into an LLZK object.
Definition SourceRef.h:36
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
A lattice for use in dense analysis.
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
SourceRef createChild(SourceRefIndex r) const
Definition SourceRef.h:252
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:60
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:83
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:75
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:20
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:68
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:601
std::variant< ScalarTy, ArrayTy > & getValue()
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
auto m_RefValue()
Definition Matchers.h:68
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_Constant()
Definition Matchers.h:88
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
Definition Matchers.h:47
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA