LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
12#include "llzk/Util/Debug.h"
14
15#include <mlir/Dialect/SCF/IR/SCF.h>
16
17#include <llvm/ADT/TypeSwitch.h>
18
19using namespace mlir;
20
21namespace llzk {
22
23using namespace array;
24using namespace boolean;
25using namespace cast;
26using namespace component;
27using namespace constrain;
28using namespace felt;
29using namespace function;
30
31/* ExpressionValue */
32
// Equality comparison. NOTE(review): the signature line is elided from this
// listing; presumably this is ExpressionValue::operator== given the section
// header above — TODO confirm.
// Two values whose expressions are both null compare by interval only.
34 if (expr == nullptr && rhs.expr == nullptr) {
35 return i == rhs.i;
36 }
// Exactly one side has a null expression: never equal.
37 if (expr == nullptr || rhs.expr == nullptr) {
38 return false;
39 }
// Both expressions are non-null: compare intervals and the pointed-to expressions.
40 return i == rhs.i && *expr == *rhs.expr;
41}
42
// Builds an ExpressionValue whose interval is the intersection of both operand
// intervals and whose expression is the solver equality `lhs.expr == rhs.expr`.
// NOTE(review): the return-type line is elided from this listing.
44intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
45 Interval res = lhs.i.intersect(rhs.i);
46 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
47 return ExpressionValue(exprEq, res);
48}
49
// Addition: interval is `lhs.i + rhs.i`, expression is a bitvector add.
// NOTE(review): the return-type line and the declaration of `res` are elided
// from this listing.
51add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
53 res.i = lhs.i + rhs.i;
54 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
55 return res;
56}
57
// Subtraction: interval is `lhs.i - rhs.i`, expression is a bitvector subtract.
// NOTE(review): the return-type line and the declaration of `res` are elided
// from this listing.
59sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
61 res.i = lhs.i - rhs.i;
62 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
63 return res;
64}
65
// Multiplication: interval is `lhs.i * rhs.i`, expression is a bitvector multiply.
// NOTE(review): the return-type line and the declaration of `res` are elided
// from this listing.
67mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
69 res.i = lhs.i * rhs.i;
70 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
71 return res;
72}
73
// Division: the interval division may fail, in which case a warning is emitted
// on `op` and the result interval falls back to the entire field. The
// expression is always an unsigned bitvector division.
// NOTE(review): the return-type line and the declaration of `res` are elided
// from this listing.
75div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs,
76 const ExpressionValue &rhs) {
78 auto divRes = lhs.i / rhs.i;
79 if (failed(divRes)) {
80 op->emitWarning(
81 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
82 " Range of division result will be treated as unbounded."
83 );
// Over-approximate: any field element is possible.
84 res.i = Interval::Entire(lhs.getField());
85 } else {
86 res.i = *divRes;
87 }
88 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
89 return res;
90}
91
// Remainder: interval is `lhs.i % rhs.i`, expression is an unsigned bitvector
// remainder.
// NOTE(review): the return-type line and the declaration of `res` are elided
// from this listing.
93mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
95 res.i = lhs.i % rhs.i;
96 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
97 return res;
98}
99
// Bitwise AND: interval is `lhs.i & rhs.i`, expression is a bitvector AND.
// NOTE(review): the return-type line is elided from this listing.
101bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
102 ExpressionValue res;
103 res.i = lhs.i & rhs.i;
104 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
105 return res;
106}
107
// Left shift: interval is `lhs.i << rhs.i`, expression is a bitvector shift-left.
// NOTE(review): the return-type line is elided from this listing.
109shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
110 ExpressionValue res;
111 res.i = lhs.i << rhs.i;
112 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
113 return res;
114}
115
// Right shift: interval is `lhs.i >> rhs.i`, expression is a logical (not
// arithmetic) bitvector shift-right.
// NOTE(review): the return-type line is elided from this listing.
117shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
118 ExpressionValue res;
119 res.i = lhs.i >> rhs.i;
120 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
121 return res;
122}
123
// Comparison of two expression values under the predicate of `op`.
// The expression is always the solver form of the predicate (unsigned bitvector
// comparisons for the ordering predicates). The interval defaults to the
// boolean range and is only narrowed to True/False for EQ and NE; the ordering
// predicates leave it at the default.
// NOTE(review): the return-type line is elided from this listing.
125cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs) {
126 ExpressionValue res;
127 const Field &f = lhs.getField();
128 // Default result is any boolean output for when we are unsure about the comparison result.
129 res.i = Interval::Boolean(f);
130 switch (op.getPredicate()) {
131 case FeltCmpPredicate::EQ:
132 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
// Both sides are single values: the answer is known exactly.
133 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
134 res.i = lhs.i == rhs.i ? Interval::True(f) : Interval::False(f);
// Disjoint ranges can never be equal.
135 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
136 res.i = Interval::False(f);
137 }
138 break;
139 case FeltCmpPredicate::NE:
140 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
141 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
142 res.i = lhs.i != rhs.i ? Interval::True(f) : Interval::False(f);
// Disjoint ranges are always unequal.
143 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
144 res.i = Interval::True(f);
145 }
146 break;
147 case FeltCmpPredicate::LT:
148 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
149 break;
150 case FeltCmpPredicate::LE:
151 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
152 break;
153 case FeltCmpPredicate::GT:
154 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
155 break;
156 case FeltCmpPredicate::GE:
157 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
158 break;
159 }
160 return res;
161}
162
// Logical AND: interval comes from the Interval overload of boolAnd, expression
// is the solver's boolean AND.
// NOTE(review): the return-type line is elided from this listing.
164boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
165 ExpressionValue res;
166 res.i = boolAnd(lhs.i, rhs.i);
167 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
168 return res;
169}
170
// Logical OR: interval comes from the Interval overload of boolOr, expression
// is the solver's boolean OR.
// NOTE(review): the return-type line is elided from this listing.
172boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
173 ExpressionValue res;
174 res.i = boolOr(lhs.i, rhs.i);
175 res.expr = solver->mkOr(lhs.expr, rhs.expr);
176 return res;
177}
178
// Logical XOR: interval comes from the Interval overload of boolXor; the
// expression is assembled from OR, AND and NOT.
// NOTE(review): the return-type line is elided from this listing.
180boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
181 ExpressionValue res;
182 res.i = boolXor(lhs.i, rhs.i);
183 // There's no Xor, so we do (L || R) && !(L && R)
184 res.expr = solver->mkAnd(
185 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
186 );
187 return res;
188}
189
// Fallback for binary ops without a dedicated interval rule: the interval is
// over-approximated to the entire field, and only the expression is modelled.
// Bitwise OR and XOR felt ops are supported; any other op is a fatal error.
// NOTE(review): the line carrying the return type and function name is elided
// from this listing; the caller in performBinaryArithmetic refers to it as
// fallbackBinaryOp.
191 llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs
192) {
193 ExpressionValue res;
194 res.i = Interval::Entire(lhs.getField());
195 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
196 .Case<OrFeltOp>([&](auto _) { return solver->mkBVOr(lhs.expr, rhs.expr); })
197 .Case<XorFeltOp>([&](auto _) {
198 return solver->mkBVXor(lhs.expr, rhs.expr);
199 }).Default([&](auto *unsupported) {
200 llvm::report_fatal_error(
201 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
202 );
// Placed after the fatal-error call; presumably only there to give the lambda a return value.
203 return nullptr;
204 });
205
206 return res;
207}
208
209ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val) {
210 ExpressionValue res;
211 res.i = -val.i;
212 res.expr = solver->mkBVNeg(val.expr);
213 return res;
214}
215
216ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val) {
217 ExpressionValue res;
218 res.i = ~val.i;
219 res.expr = solver->mkBVNot(val.expr);
220 return res;
221}
222
223ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val) {
224 ExpressionValue res;
225 res.i = boolNot(val.i);
226 res.expr = solver->mkNot(val.expr);
227 return res;
228}
229
// Fallback for unary ops without a dedicated interval rule: the interval is
// over-approximated to the entire field, and only the expression is modelled.
// Only the felt inverse op is supported; any other op is a fatal error.
// NOTE(review): the return-type line is elided from this listing.
231fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val) {
232 const Field &field = val.getField();
233 ExpressionValue res;
234 res.i = Interval::Entire(field);
235 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
236 .Case<InvFeltOp>([&](InvFeltOp _) {
237 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
238 // To create this expression, we create a new symbol for Y and add the
239 // XY % prime = 1 constraint to the solver.
// The symbol name is derived from the op itself.
240 std::string symName = buildStringViaInsertionOp(*op);
241 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
242 llvm::SMTExprRef one = solver->mkBitvector(field.one(), field.bitWidth());
243 llvm::SMTExprRef prime = solver->mkBitvector(field.prime(), field.bitWidth());
244 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
245 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
246 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
// Side effect: the constraint is registered directly with the solver.
247 solver->addConstraint(constraint);
248 return invSym;
249 }).Default([&](Operation *unsupported) {
250 llvm::report_fatal_error(
251 "no fallback provided for " + mlir::Twine(op->getName().getStringRef())
252 );
// Placed after the fatal-error call; presumably only there to give the lambda a return value.
253 return nullptr;
254 });
255
256 return res;
257}
258
259void ExpressionValue::print(mlir::raw_ostream &os) const {
260 if (expr) {
261 expr->print(os);
262 } else {
263 os << "<null expression>";
264 }
265
266 os << " ( interval: " << i << " )";
267}
268
269/* IntervalAnalysisLattice */
270
271ChangeResult IntervalAnalysisLattice::join(const AbstractDenseLattice &other) {
272 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
273 if (!rhs) {
274 llvm::report_fatal_error("invalid join lattice type");
275 }
276 ChangeResult res = ChangeResult::NoChange;
277 for (auto &[k, v] : rhs->valMap) {
278 auto it = valMap.find(k);
279 if (it == valMap.end() || it->second != v) {
280 valMap[k] = v;
281 res |= ChangeResult::Change;
282 }
283 }
284 for (auto &v : rhs->constraints) {
285 if (!constraints.contains(v)) {
286 constraints.insert(v);
287 res |= ChangeResult::Change;
288 }
289 }
290 for (auto &[e, i] : rhs->intervals) {
291 auto it = intervals.find(e);
292 if (it == intervals.end() || it->second != i) {
293 intervals[e] = i;
294 res |= ChangeResult::Change;
295 }
296 }
297 return res;
298}
299
300void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
301 os << "IntervalAnalysisLattice { ";
302 for (auto &[ref, val] : valMap) {
303 os << "\n (valMap) " << ref << " := " << val;
304 }
305 for (auto &[expr, interval] : intervals) {
306 os << "\n (intervals) ";
307 if (!expr) {
308 os << "<null expr>";
309 } else {
310 expr->print(os);
311 }
312 os << " in " << interval;
313 }
314 if (!valMap.empty()) {
315 os << '\n';
316 }
317 os << '}';
318}
319
320FailureOr<IntervalAnalysisLattice::LatticeValue> IntervalAnalysisLattice::getValue(Value v) const {
321 auto it = valMap.find(v);
322 if (it == valMap.end()) {
323 return failure();
324 }
325 return it->second;
326}
327
328FailureOr<IntervalAnalysisLattice::LatticeValue>
329IntervalAnalysisLattice::getValue(Value v, StringAttr f) const {
330 auto it = fieldMap.find(v);
331 if (it == fieldMap.end()) {
332 return failure();
333 }
334 auto fit = it->second.find(f);
335 if (fit == it->second.end()) {
336 return failure();
337 }
338 return fit->second;
339}
340
341ChangeResult IntervalAnalysisLattice::setValue(Value v, ExpressionValue e) {
342 LatticeValue val(e);
343 if (valMap[v] == val) {
344 return ChangeResult::NoChange;
345 }
346 valMap[v] = val;
347 intervals[e.getExpr()] = e.getInterval();
348 return ChangeResult::Change;
349}
350
351ChangeResult IntervalAnalysisLattice::setValue(Value v, StringAttr f, ExpressionValue e) {
352 LatticeValue val(e);
353 if (fieldMap[v][f] == val) {
354 return ChangeResult::NoChange;
355 }
356 fieldMap[v][f] = val;
357 intervals[e.getExpr()] = e.getInterval();
358 return ChangeResult::Change;
359}
360
// Adds `e` to the constraint set if it is not already present, reporting
// whether the set changed.
// NOTE(review): the signature line is elided from this listing; call sites in
// this file invoke it as `after->addSolverConstraint(...)`.
362 if (!constraints.contains(e)) {
363 constraints.insert(e);
364 return ChangeResult::Change;
365 }
366 return ChangeResult::NoChange;
367}
368
369FailureOr<Interval> IntervalAnalysisLattice::findInterval(llvm::SMTExprRef expr) const {
370 auto it = intervals.find(expr);
371 if (it != intervals.end()) {
372 return it->second;
373 }
374 return failure();
375}
376
377ChangeResult IntervalAnalysisLattice::setInterval(llvm::SMTExprRef expr, Interval i) {
378 auto it = intervals.find(expr);
379 if (it != intervals.end() && it->second == i) {
380 return ChangeResult::NoChange;
381 }
382 intervals[expr] = i;
383 return ChangeResult::Change;
384}
385
386/* IntervalDataFlowAnalysis */
387
// Transfer function across a call, dispatched on `action`.
// NOTE(review): the first signature line and the line declaring the remaining
// parameters are elided from this listing; the body uses `before` and `after`
// lattices, so those are presumably the missing parameters — TODO confirm.
392 CallOpInterface call, dataflow::CallControlFlowAction action,
394) {
398 if (action == dataflow::CallControlFlowAction::EnterCallee) {
399 // We skip updating the incoming lattice for function calls,
400 // as values are relative to the containing function/struct, so we don't need to pollute
401 // the callee with the callers values.
402 setToEntryState(after);
403 }
407 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
408 // Get the argument values of the lattice by getting the state as it would
409 // have been for the callsite.
410 dataflow::AbstractDenseLattice *beforeCall = nullptr;
// Use the lattice of the op preceding the call, or of the block when the call
// is the first op in it.
411 if (auto *prev = call->getPrevNode()) {
412 beforeCall = getLattice(prev);
413 } else {
414 beforeCall = getLattice(call->getBlock());
415 }
416 ensure(beforeCall, "could not get prior lattice");
417
418 // The lattice at the return is the lattice before the call
419 propagateIfChanged(after, after->join(*beforeCall));
420 }
425 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
426 // For external calls, we propagate what information we already have from
427 // before the call to after the call, since the external call won't invalidate
428 // any of that information. It also, conservatively, makes no assumptions about
429 // external calls and their computation, so CDG edges will not be computed over
430 // input arguments to external functions.
431 join(after, before);
432 }
433}
434
// Per-operation transfer function.
// NOTE(review): the first signature line is elided from this listing; a comment
// in applyInterval refers to this routine as visitOperation — TODO confirm.
//
// Outline:
//   1. `after` starts as the join of itself with `before`.
//   2. A lattice value is resolved for every operand, either from `before` or
//      from the constrain-ref lattice (creating solver symbols as needed).
//   3. `after` is updated according to the kind of `op`.
//   4. `after` is propagated if anything changed.
436 Operation *op, const Lattice &before, Lattice *after
437) {
438 ChangeResult changed = after->join(before);
439
440 llvm::SmallVector<LatticeValue> operandVals;
441
442 auto constrainRefLattice = _dataflowSolver.lookupState<ConstrainRefLattice>(op);
443 ensure(constrainRefLattice, "failed to get lattice");
444
445 for (OpOperand &operand : op->getOpOperands()) {
446 Value val = operand.get();
447 // First, lookup the operand value in the before state.
448 auto priorState = before.getValue(val);
// A prior entry with a null expression carries only an interval, so it is
// not used directly; the expression is (re)built below.
449 if (succeeded(priorState) && priorState->getScalarValue().getExpr() != nullptr) {
450 operandVals.push_back(*priorState);
451 continue;
452 }
453
454 // Else, look up the stored value by constrain ref.
455 // We only care about scalar type values, so we ignore composite types, which
456 // are currently limited to non-Signal structs and arrays.
457 Type valTy = val.getType();
458 if (mlir::isa<ArrayType, StructType>(valTy) && !isSignalType(valTy)) {
// Placeholder keeps operandVals index-aligned with the op's operands.
459 operandVals.push_back(LatticeValue());
460 continue;
461 }
462
463 ConstrainRefLatticeValue refSet = constrainRefLattice->getOrDefault(val);
464 ensure(refSet.isScalar(), "should have ruled out array values already");
465
466 if (refSet.getScalarValue().empty()) {
467 // If we can't compute the reference, then there must be some unsupported
468 // op the reference analysis cannot handle. We emit a warning and return
469 // early, since there's no meaningful computation we can do for this op.
470 op->emitWarning() << "state of " << val
471 << " is empty; defining operation is unsupported by constrain ref analysis";
472 propagateIfChanged(after, changed);
473 return;
474 } else if (!refSet.isSingleValue()) {
475 std::string warning;
476 debug::Appender(warning) << "operand " << val << " is not a single value " << refSet
477 << ", overapproximating";
478 op->emitWarning(warning);
479 // Here, we will override the prior lattice value with a new symbol, representing
480 // "any" value, then use that value for the operands.
481 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
482 changed |= after->setValue(val, anyVal);
483 operandVals.emplace_back(anyVal);
484 } else {
485 auto ref = refSet.getSingleValue();
486 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
// Keep an interval that was already recorded for this value, if any.
487 if (succeeded(priorState)) {
488 exprVal = exprVal.withInterval(priorState->getScalarValue().getInterval());
489 }
490 changed |= after->setValue(val, exprVal);
491 operandVals.emplace_back(exprVal);
492 }
493 }
494
495 // Now, the way we update is dependent on the type of the operation.
496 if (isConstOp(op)) {
497 auto constVal = getConst(op);
498 auto expr = createConstBitvectorExpr(constVal);
499 ExpressionValue latticeVal(field.get(), expr, constVal);
500 changed |= after->setValue(op->getResult(0), latticeVal);
501 } else if (isArithmeticOp(op)) {
502 ensure(operandVals.size() <= 2, "arithmetic op with the wrong number of operands");
503 ExpressionValue result;
504 if (operandVals.size() == 2) {
505 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
506 } else {
// NOTE(review): the ensure above only bounds the operand count from above, so
// an arithmetic op with zero operands would index out of range here — confirm
// that cannot occur.
507 result = performUnaryArithmetic(op, operandVals[0]);
508 }
509
510 changed |= after->setValue(op->getResult(0), result);
511 } else if (EmitEqualityOp emitEq = mlir::dyn_cast<EmitEqualityOp>(op)) {
512 ensure(operandVals.size() == 2, "constraint op with the wrong number of operands");
513 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
514 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
515 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
516
517 // Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
518 // These patterns enforce that s is one of c0, ..., cN.
519 auto res = getGeneralizedDecompInterval(constrainRefLattice, lhsVal, rhsVal);
520 if (succeeded(res)) {
521 for (Value signalVal : res->first) {
522 changed |= applyInterval(emitEq, after, signalVal, res->second);
523 }
524 }
525
526 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
527 // Update the LHS and RHS to the same value, but restricted intervals
528 // based on the constraints
529 changed |= applyInterval(emitEq, after, lhsVal, constraint.getInterval());
530 changed |= applyInterval(emitEq, after, rhsVal, constraint.getInterval());
531 changed |= after->addSolverConstraint(constraint);
532 } else if (AssertOp assertOp = mlir::dyn_cast<AssertOp>(op)) {
533 ensure(operandVals.size() == 1, "assert op with the wrong number of operands");
534 // assert enforces that the operand is true. So we apply an interval of [1, 1]
535 // to the operand.
536 changed |= applyInterval(
537 assertOp, after, assertOp.getCondition(),
538 Interval::Degenerate(field.get(), field.get().one())
539 );
540 // Also add the solver constraint that the expression must be true.
541 auto assertExpr = operandVals[0].getScalarValue();
542 changed |= after->addSolverConstraint(assertExpr);
543 } else if (auto readf = mlir::dyn_cast<FieldReadOp>(op)) {
544 if (isSignalType(readf.getComponent().getType())) {
545 // The reg value read from the signal type is equal to the value of the Signal
546 // struct overall.
547 changed |= after->setValue(readf.getVal(), operandVals[0].getScalarValue());
548 } else if (auto storedVal =
549 before.getValue(readf.getComponent(), readf.getFieldNameAttr().getAttr());
550 succeeded(storedVal)) {
551 // The result value is the value previously written to this field.
552 changed |= after->setValue(readf.getVal(), storedVal->getScalarValue());
553 }
554 } else if (auto writef = mlir::dyn_cast<FieldWriteOp>(op)) {
555 // Update values stored in a field
// operandVals[1] is taken as the written value (operand 0 being the component).
556 ExpressionValue writeVal = operandVals[1].getScalarValue();
557 changed |=
558 after->setValue(writef.getComponent(), writef.getFieldNameAttr().getAttr(), writeVal);
559 // We also need to update the interval on the assigned symbol
560 ConstrainRefLatticeValue refSet = constrainRefLattice->getOrDefault(writef.getComponent());
561 if (refSet.isSingleValue()) {
562 auto fieldDefRes = writef.getFieldDefOp(tables);
563 if (succeeded(fieldDefRes)) {
564 ConstrainRefIndex idx(fieldDefRes.value());
565 ConstrainRef fieldRef = refSet.getSingleValue().createChild(idx);
566 llvm::SMTExprRef expr = getOrCreateSymbol(fieldRef);
567 changed |= after->setInterval(expr, writeVal.getInterval());
568 }
569 }
570 } else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
571 // Casts don't modify the intervals
572 changed |= after->setValue(op->getResult(0), operandVals[0].getScalarValue());
573 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
574 // Fetch the lattice of the parent operation
575 Operation *parent = op->getParentOp();
576 ensure(parent, "yield operation must have parent lattice");
577 auto *parentLattice = static_cast<IntervalAnalysisLattice *>(getLattice(parent));
578 ensure(parentLattice, "could not fetch parent lattice");
579 // Bind the operand values to the result values of the parent
// NOTE(review): operandVals and the parent's results are both indexed by the
// yield's result count — confirm these counts always agree.
580 for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
581 Value parentRes = parent->getResult(idx);
582 // Merge with the existing value, if present
583 auto exprValRes = parentLattice->getValue(parentRes);
584 ExpressionValue newResVal = operandVals[idx].getScalarValue();
585 if (succeeded(exprValRes)) {
586 ExpressionValue existingVal = exprValRes->getScalarValue();
587 newResVal =
588 existingVal.withInterval(existingVal.getInterval().join(newResVal.getInterval()));
589 } else {
// First time this result is seen: give it a fresh symbol, keeping the yielded interval.
590 newResVal = ExpressionValue(createFeltSymbol(parentRes), newResVal.getInterval());
591 }
592 changed |= after->setValue(parentRes, newResVal);
593 }
594
595 propagateIfChanged(parentLattice, parentLattice->join(*after));
596 } else if (
597 // We do not need to explicitly handle read ops since they are resolved at the operand value
598 // step where constrain refs are queries (with the exception of the Signal struct, see above).
599 !isReadOp(op)
600 // We do not currently handle return ops as the analysis is currently limited to constrain
601 // functions, which return no value.
602 && !isReturnOp(op)
603 // The analysis ignores definition ops.
604 && !isDefinitionOp(op)
605 // We do not need to analyze the creation of structs.
606 && !mlir::isa<CreateStructOp>(op)
607 ) {
608 op->emitWarning("unhandled operation, analysis may be incomplete");
609 }
610
611 propagateIfChanged(after, changed);
612}
613
// Returns the solver symbol cached for `r`, creating and caching a new felt
// symbol on first use.
// NOTE(review): the signature line is elided from this listing; call sites in
// this file invoke it as getOrCreateSymbol(ref).
615 auto it = refSymbols.find(r);
616 if (it != refSymbols.end()) {
617 return it->second;
618 }
619 auto sym = createFeltSymbol(r);
620 refSymbols[r] = sym;
621 return sym;
622}
623
624llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const ConstrainRef &r) const {
625 return createFeltSymbol(buildStringViaPrint(r).c_str());
626}
627
628llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v) const {
629 return createFeltSymbol(buildStringViaPrint(v).c_str());
630}
631
632llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) const {
633 return field.get().createSymbol(smtSolver, name);
634}
635
636llvm::APSInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
637 ensure(isConstOp(op), "op is not a const op");
638
639 llvm::APInt fieldConst =
640 TypeSwitch<Operation *, llvm::APInt>(op)
641 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
642 llvm::APSInt constOpVal(feltConst.getValueAttr().getValue());
643 return field.get().reduce(constOpVal);
644 })
645 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
646 return llvm::APInt(field.get().bitWidth(), indexConst.value());
647 })
648 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
649 return llvm::APInt(field.get().bitWidth(), intConst.value());
650 }).Default([](Operation *illegalOp) {
651 std::string err;
652 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
653 llvm::report_fatal_error(Twine(err));
654 return llvm::APInt();
655 });
656 return safeToSigned(fieldConst);
657}
658
659ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
660 Operation *op, const LatticeValue &a, const LatticeValue &b
661) {
662 ensure(isArithmeticOp(op), "is not arithmetic op");
663
664 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
665 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
666 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
667
668 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
669 .Case<AddFeltOp>([&](auto _) { return add(smtSolver, lhs, rhs); })
670 .Case<SubFeltOp>([&](auto _) { return sub(smtSolver, lhs, rhs); })
671 .Case<MulFeltOp>([&](auto _) { return mul(smtSolver, lhs, rhs); })
672 .Case<DivFeltOp>([&](auto divOp) { return div(smtSolver, divOp, lhs, rhs); })
673 .Case<ModFeltOp>([&](auto _) { return mod(smtSolver, lhs, rhs); })
674 .Case<AndFeltOp>([&](auto _) { return bitAnd(smtSolver, lhs, rhs); })
675 .Case<ShlFeltOp>([&](auto _) { return shiftLeft(smtSolver, lhs, rhs); })
676 .Case<ShrFeltOp>([&](auto _) { return shiftRight(smtSolver, lhs, rhs); })
677 .Case<CmpOp>([&](auto cmpOp) { return cmp(smtSolver, cmpOp, lhs, rhs); })
678 .Case<AndBoolOp>([&](auto _) { return boolAnd(smtSolver, lhs, rhs); })
679 .Case<OrBoolOp>([&](auto _) { return boolOr(smtSolver, lhs, rhs); })
680 .Case<XorBoolOp>([&](auto _) {
681 return boolXor(smtSolver, lhs, rhs);
682 }).Default([&](auto *unsupported) {
683 unsupported->emitWarning(
684 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
685 );
686 return fallbackBinaryOp(smtSolver, unsupported, lhs, rhs);
687 });
688
689 ensure(res.getExpr(), "arithmetic produced null smt expr");
690 return res;
691}
692
// Evaluates a one-operand arithmetic op over the scalar value of `a`.
// Negation, bitwise NOT and boolean NOT have dedicated helpers; any other op
// emits a warning and goes through fallbackUnaryOp.
// `op` must satisfy isArithmeticOp, and both the operand and the result must
// carry a non-null expression.
// NOTE(review): the return-type line is elided from this listing.
694IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
695 ensure(isArithmeticOp(op), "is not arithmetic op");
696
697 auto val = a.getScalarValue();
698 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
699
700 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
701 .Case<NegFeltOp>([&](auto _) { return neg(smtSolver, val); })
702 .Case<NotFeltOp>([&](auto _) { return notOp(smtSolver, val); })
703 .Case<NotBoolOp>([&](auto _) {
704 return boolNot(smtSolver, val);
705 }).Default([&](auto *unsupported) {
706 unsupported->emitWarning(
707 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
708 );
709 return fallbackUnaryOp(smtSolver, unsupported, val);
710 });
711
712 ensure(res.getExpr(), "arithmetic produced null smt expr");
713 return res;
714}
715
716ChangeResult IntervalDataFlowAnalysis::applyInterval(
717 Operation *originalOp, Lattice *after, Value val, Interval newInterval
718) {
719 auto latValRes = after->getValue(val);
720 if (failed(latValRes)) {
721 // visitOperation didn't add val to the lattice, so there's nothing to do
722 return ChangeResult::NoChange;
723 }
724 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
725 ChangeResult res = after->setValue(val, newLatticeVal);
726 // To allow the dataflow analysis to do its fixed-point iteration, we need to
727 // add the new expression to val's lattice as well.
728 Lattice *valLattice = nullptr;
729 if (auto valOp = val.getDefiningOp()) {
730 // Getting the lattice at valOp gives us the "after" lattice, but we want to
731 // update the "before" lattice so that the inputs to visitOperation will be
732 // changed.
733 if (auto prev = valOp->getPrevNode()) {
734 valLattice = getOrCreate<Lattice>(prev);
735 } else {
736 valLattice = getOrCreate<Lattice>(valOp->getBlock());
737 }
738 } else if (auto blockArg = mlir::dyn_cast<BlockArgument>(val)) {
739 Operation *owningOp = blockArg.getOwner()->getParentOp();
740 if (propagateInputConstraints) {
741 // Apply the interval from the constrain function inputs to the compute function inputs
742 auto fnOp = dyn_cast<FuncDefOp>(owningOp);
743 if (fnOp && fnOp.isStructConstrain() && blockArg.getArgNumber() > 0 &&
744 !newInterval.isEntire()) {
745 auto structOp = fnOp->getParentOfType<StructDefOp>();
746 FuncDefOp computeFn = structOp.getComputeFuncOp();
747 Operation *computeEntry = &computeFn.getRegion().front().front();
748 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
749 Lattice *computeEntryLattice = getOrCreate<Lattice>(computeEntry);
750 auto entryLatticeVal = computeEntryLattice->getValue(computeArg);
751 ExpressionValue newArgVal;
752 if (succeeded(entryLatticeVal)) {
753 newArgVal = entryLatticeVal->getScalarValue().withInterval(newInterval);
754 } else {
755 // We store the interval with an empty expression so that when the operation
756 // is visited, the expressions can be properly generated with an existing
757 // interval.
758 newArgVal = ExpressionValue(nullptr, newInterval);
759 }
760 ChangeResult computeRes = computeEntryLattice->setValue(computeArg, newArgVal);
761 propagateIfChanged(computeEntryLattice, computeRes);
762 }
763 }
764
765 valLattice = getOrCreate<Lattice>(blockArg.getOwner());
766 } else {
767 valLattice = getOrCreate<Lattice>(val);
768 }
769 ensure(valLattice, "val should have a lattice");
770 auto setNewVal = [&valLattice, &after, &val, &newLatticeVal, this]() {
771 if (valLattice != after) {
772 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
773 }
774 };
775
776 // Now we descend into val's operands, if it has any.
777 Operation *definingOp = val.getDefiningOp();
778 if (!definingOp) {
779 setNewVal();
780 return res;
781 }
782
783 const Field &f = field.get();
784
785 // This is a rules-based operation. If we have a rule for a given operation,
786 // then we can make some kind of update, otherwise we leave the intervals
787 // as is.
788 // - First we'll define all the rules so the type switch can be less messy
789
790 // cmp.<pred> restricts each side of the comparison if the result is known.
791 auto cmpCase = [&](CmpOp cmpOp) {
792 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
793 // either "true" (1) or "false" (0)
794 ensure(
795 newInterval.isBoolean(),
796 "new interval for CmpOp outside of allowed boolean range or is empty"
797 );
798 if (!newInterval.isDegenerate()) {
799 // The comparison result is unknown, so we can't update the operand ranges
800 return ChangeResult::NoChange;
801 }
802
803 bool cmpTrue = newInterval.rhs() == f.one();
804
805 Value lhs = cmpOp->getOperand(0), rhs = cmpOp->getOperand(1);
806 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
807 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
808 return ChangeResult::NoChange;
809 }
810 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
811 rhsExpr = rhsLatValRes->getScalarValue();
812
813 Interval newLhsInterval, newRhsInterval;
814 const Interval &lhsInterval = lhsExpr.getInterval();
815 const Interval &rhsInterval = rhsExpr.getInterval();
816
817 FeltCmpPredicate pred = cmpOp.getPredicate();
818 // predicate cases
819 auto eqCase = [&]() {
820 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
821 (pred == FeltCmpPredicate::NE && !cmpTrue);
822 };
823 auto neCase = [&]() {
824 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
825 (pred == FeltCmpPredicate::EQ && !cmpTrue);
826 };
827 auto ltCase = [&]() {
828 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
829 (pred == FeltCmpPredicate::GE && !cmpTrue);
830 };
831 auto leCase = [&]() {
832 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
833 (pred == FeltCmpPredicate::GT && !cmpTrue);
834 };
835 auto gtCase = [&]() {
836 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
837 (pred == FeltCmpPredicate::LE && !cmpTrue);
838 };
839 auto geCase = [&]() {
840 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
841 (pred == FeltCmpPredicate::LT && !cmpTrue);
842 };
843
844 // new intervals based on case
845 if (eqCase()) {
846 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
847 } else if (neCase()) {
848
849 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
850 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
851 // an empty value range.
852 newLhsInterval = newRhsInterval = Interval::Empty(f);
853 } else if (lhsInterval.isDegenerate()) {
854 // rhs must not overlap with lhs
855 newLhsInterval = lhsInterval;
856 newRhsInterval = rhsInterval.difference(lhsInterval);
857 } else if (rhsInterval.isDegenerate()) {
858 // lhs must not overlap with rhs
859 newLhsInterval = lhsInterval.difference(rhsInterval);
860 newRhsInterval = rhsInterval;
861 } else {
862 // Leave unchanged
863 newLhsInterval = lhsInterval;
864 newRhsInterval = rhsInterval;
865 }
866 } else if (ltCase()) {
867 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
868 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
869 } else if (leCase()) {
870 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
871 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
872 } else if (gtCase()) {
873 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
874 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
875 } else if (geCase()) {
876 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
877 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
878 } else {
879 cmpOp->emitWarning("unhandled cmp predicate");
880 return ChangeResult::NoChange;
881 }
882
883 // Now we recurse to each operand
884 return applyInterval(originalOp, after, lhs, newLhsInterval) |
885 applyInterval(originalOp, after, rhs, newRhsInterval);
886 };
887
888 // If the result of a multiplication is non-zero, then both operands must be
889 // non-zero.
890 auto mulCase = [&](MulFeltOp mulOp) {
891 auto zeroInt = Interval::Degenerate(f, f.zero());
892 if (newInterval.intersect(zeroInt).isNotEmpty()) {
893 // The multiplication may be zero, so we can't reduce the operands to be non-zero
894 return ChangeResult::NoChange;
895 }
896
897 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
898 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
899 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
900 return ChangeResult::NoChange;
901 }
902 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
903 rhsExpr = rhsLatValRes->getScalarValue();
904 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
905 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
906 return applyInterval(originalOp, after, lhs, newLhsInterval) |
907 applyInterval(originalOp, after, rhs, newRhsInterval);
908 };
909
910 // We have a special case for the Signal struct: if this value is created
911 // from reading a Signal struct's reg field, we also apply the interval to
912 // the struct itself.
913 auto readfCase = [&](FieldReadOp readfOp) {
914 Value comp = readfOp.getComponent();
915 if (isSignalType(comp.getType())) {
916 return applyInterval(originalOp, after, comp, newInterval);
917 }
918 return ChangeResult::NoChange;
919 };
920
921 // - Apply the rules given the op.
922 // NOTE: disabling clang-format for this because it makes the last case statement
923 // look ugly.
924 // clang-format off
925 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
926 .Case<CmpOp>([&](auto op) { return cmpCase(op); })
927 .Case<MulFeltOp>([&](auto op) { return mulCase(op); })
928 .Case<FieldReadOp>([&](auto op){ return readfCase(op); })
929 .Default([&](auto *_) { return ChangeResult::NoChange; });
930 // clang-format on
931
932 // Set the new val after recursion to avoid having recursive calls unset the value.
933 setNewVal();
934
935 return res;
936}
937
938FailureOr<std::pair<DenseSet<Value>, Interval>>
939IntervalDataFlowAnalysis::getGeneralizedDecompInterval(
940 const ConstrainRefLattice *constrainRefLattice, Value lhs, Value rhs
941) {
942 auto isZeroConst = [this](Value v) {
943 Operation *op = v.getDefiningOp();
944 if (!op) {
945 return false;
946 }
947 if (!isConstOp(op)) {
948 return false;
949 }
950 llvm::APSInt c = getConst(op);
951 return safeEq(c, field.get().zero());
952 };
953 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
954 Value exprTree = nullptr;
955 if (lhsIsZero && !rhsIsZero) {
956 exprTree = rhs;
957 } else if (!lhsIsZero && rhsIsZero) {
958 exprTree = lhs;
959 } else {
960 return failure();
961 }
962
963 // We now explore the expression tree for multiplications of subtractions/signal values.
964 std::optional<ConstrainRef> signalRef = std::nullopt;
965 DenseSet<Value> signalVals;
966 SmallVector<APSInt> consts;
967 SmallVector<Value> frontier {exprTree};
968 while (!frontier.empty()) {
969 Value v = frontier.back();
970 frontier.pop_back();
971 Operation *op = v.getDefiningOp();
972
973 FeltConstantOp c;
974 Value signalVal;
975 auto handleRefValue = [&constrainRefLattice, &signalRef, &signalVal, &signalVals]() {
976 ConstrainRefLatticeValue refSet = constrainRefLattice->getOrDefault(signalVal);
977 if (!refSet.isScalar() || !refSet.isSingleValue()) {
978 return failure();
979 }
980 ConstrainRef r = refSet.getSingleValue();
981 if (signalRef.has_value() && signalRef.value() != r) {
982 return failure();
983 } else if (!signalRef.has_value()) {
984 signalRef = r;
985 }
986 signalVals.insert(signalVal);
987 return success();
988 };
989
990 auto subPattern = m_CommutativeOp<SubFeltOp>(m_RefValue(&signalVal), m_Constant(&c));
991 if (op && matchPattern(op, subPattern)) {
992 if (failed(handleRefValue())) {
993 return failure();
994 }
995 auto constInt = APSInt(c.getValueAttr().getValue());
996 consts.push_back(field.get().reduce(constInt));
997 continue;
998 } else if (m_RefValue(&signalVal).match(v)) {
999 if (failed(handleRefValue())) {
1000 return failure();
1001 }
1002 consts.push_back(field.get().zero());
1003 continue;
1004 }
1005
1006 Value a, b;
1007 auto mulPattern = m_CommutativeOp<MulFeltOp>(matchers::m_Any(&a), matchers::m_Any(&b));
1008 if (op && matchPattern(op, mulPattern)) {
1009 frontier.push_back(a);
1010 frontier.push_back(b);
1011 continue;
1012 }
1013
1014 return failure();
1015 }
1016
1017 // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4),
1018 // we will create a larger range of [0, 4], since we don't support multiple intervals.
1019 std::sort(consts.begin(), consts.end());
1020 Interval iv = Interval::TypeA(field.get(), consts.front(), consts.back());
1021 return std::make_pair(std::move(signalVals), iv);
1022}
1023
1024/* StructIntervals */
1025
1026LogicalResult
1028
1029 auto computeIntervalsImpl = [&solver, &ctx, this](
1030 FuncDefOp fn,
1031 llvm::MapVector<ConstrainRef, Interval> &fieldRanges,
1032 llvm::SetVector<ExpressionValue> &solverConstraints
1033 ) {
1034 // Get the lattice at the end of the function.
1035 Operation *fnEnd = fn.getRegion().back().getTerminator();
1036
1037 const IntervalAnalysisLattice *lattice = solver.lookupState<IntervalAnalysisLattice>(fnEnd);
1038
1039 solverConstraints = lattice->getConstraints();
1040
1041 for (const auto &ref : ConstrainRef::getAllConstrainRefs(structDef, fn)) {
1042 // We only want to compute intervals for field elements and not composite types,
1043 // with the exception of the Signal struct.
1044 if (!ref.isScalar() && !ref.isSignal()) {
1045 continue;
1046 }
1047 // We also don't want to show the interval for a Signal and its internal reg.
1048 if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1049 continue;
1050 }
1051
1052 auto symbol = ctx.getSymbol(ref);
1053 auto intervalRes = lattice->findInterval(symbol);
1054 if (succeeded(intervalRes)) {
1055 fieldRanges[ref] = *intervalRes;
1056 } else {
1057 fieldRanges[ref] = Interval::Entire(ctx.field);
1058 }
1059 }
1060 };
1061
1062 computeIntervalsImpl(structDef.getComputeFuncOp(), computeFieldRanges, computeSolverConstraints);
1063 computeIntervalsImpl(
1064 structDef.getConstrainFuncOp(), constrainFieldRanges, constrainSolverConstraints
1065 );
1066
1067 return success();
1068}
1069
1070void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints, bool printCompute) const {
1071 auto writeIntervals =
1072 [&os, &withConstraints](
1073 const char *fnName, const llvm::MapVector<ConstrainRef, Interval> &fieldRanges,
1074 const llvm::SetVector<ExpressionValue> &solverConstraints, bool printName
1075 ) {
1076 int indent = 4;
1077 if (printName) {
1078 os << '\n';
1079 os.indent(indent) << fnName << " {";
1080 indent += 4;
1081 }
1082
1083 if (fieldRanges.empty()) {
1084 os << "}\n";
1085 return;
1086 }
1087
1088 for (auto &[ref, interval] : fieldRanges) {
1089 os << '\n';
1090 os.indent(indent) << ref << " in " << interval;
1091 }
1092
1093 if (withConstraints) {
1094 os << "\n\n";
1095 os.indent(indent) << "Solver Constraints { ";
1096 if (solverConstraints.empty()) {
1097 os << "}\n";
1098 } else {
1099 for (const auto &e : solverConstraints) {
1100 os << '\n';
1101 os.indent(indent + 4);
1102 e.getExpr()->print(os);
1103 }
1104 os << '\n';
1105 os.indent(indent) << '}';
1106 }
1107 }
1108
1109 if (printName) {
1110 os << '\n';
1111 os.indent(indent - 4) << '}';
1112 }
1113 };
1114
1115 os << "StructIntervals { ";
1116 if (constrainFieldRanges.empty() && (!printCompute || computeFieldRanges.empty())) {
1117 os << "}\n";
1118 return;
1119 }
1120
1121 if (printCompute) {
1122 writeIntervals(FUNC_NAME_COMPUTE, computeFieldRanges, computeSolverConstraints, printCompute);
1123 }
1124 writeIntervals(
1125 FUNC_NAME_CONSTRAIN, constrainFieldRanges, constrainSolverConstraints, printCompute
1126 );
1127
1128 os << "\n}\n";
1129}
1130
1131} // namespace llzk
MlirStringRef name
Defines an index into an LLZK object.
A value at a given point of the ConstrainRefLattice.
const ConstrainRef & getSingleValue() const
A lattice for use in dense analysis.
Defines a reference to a llzk object within a constrain function call.
ConstrainRef createChild(ConstrainRefIndex r) const
static std::vector< ConstrainRef > getAllConstrainRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, ConstrainRef root)
Produce all possible ConstrainRefs that are present starting from the given root.
Tracks a solver expression and an interval range for that expression.
const Interval & getInterval() const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:22
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:46
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
Definition Field.h:60
unsigned bitWidth() const
Definition Field.h:57
llvm::APSInt prime() const
For the prime field p, returns p.
Definition Field.h:34
Maps mlir::Values to LatticeValues.
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, Interval i)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
Intervals over a finite field.
Definition Intervals.h:214
bool isEmpty() const
Definition Intervals.h:318
static Interval True(const Field &f)
Definition Intervals.h:233
Interval intersect(const Interval &rhs) const
Intersect.
static Interval Boolean(const Field &f)
Definition Intervals.h:235
bool isDegenerate() const
Definition Intervals.h:320
static Interval False(const Field &f)
Definition Intervals.h:231
static Interval Degenerate(const Field &f, llvm::APSInt val)
Definition Intervals.h:227
Interval join(const Interval &rhs) const
Union.
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, IntervalAnalysisContext &ctx)
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:777
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
IntervalAnalysisLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
auto m_RefValue()
Definition Matchers.h:68
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
llvm::APSInt safeToSigned(llvm::APSInt i)
Safely converts the given int to a signed int if it is an unsigned int by adding an extra bit for the...
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_Constant()
Definition Matchers.h:88
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:74
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
Definition Matchers.h:47
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
llvm::SMTExprRef getSymbol(const ConstrainRef &r)