LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
12#include "llzk/Util/Debug.h"
14
15#include <llvm/ADT/TypeSwitch.h>
16
17using namespace mlir;
18
19namespace llzk {
20
21using namespace array;
22using namespace boolean;
23using namespace component;
24using namespace constrain;
25using namespace felt;
26
27/* Field */
28
29Field::Field(std::string_view primeStr) : primeMod(llvm::APSInt(primeStr)) {
30 halfPrime = (primeMod + felt(1)) / felt(2);
31}
32
33const Field &Field::getField(const char *fieldName) {
34 static llvm::DenseMap<llvm::StringRef, Field> knownFields;
35 static std::once_flag fieldsInit;
36 std::call_once(fieldsInit, initKnownFields, knownFields);
37
38 if (auto it = knownFields.find(fieldName); it != knownFields.end()) {
39 return it->second;
40 }
41 llvm::report_fatal_error("field \"" + mlir::Twine(fieldName) + "\" is unsupported");
42}
43
44void Field::initKnownFields(llvm::DenseMap<llvm::StringRef, Field> &knownFields) {
45 // bn128/254, default for circom
46 knownFields.try_emplace(
47 "bn128",
48 Field("21888242871839275222246405745257275088696311157297823662689037894645226208583")
49 );
50 knownFields.try_emplace("bn254", knownFields.at("bn128"));
51 // 15 * 2^27 + 1, default for zirgen
52 knownFields.try_emplace("babybear", Field("2013265921"));
53 // 2^64 - 2^32 + 1, used for plonky2
54 knownFields.try_emplace("goldilocks", Field("18446744069414584321"));
55 // 2^31 - 1, used for Plonky3
56 knownFields.try_emplace("mersenne31", Field("2147483647"));
57}
58
59llvm::APSInt Field::reduce(llvm::APSInt i) const {
60 unsigned maxBits = std::max(i.getBitWidth(), bitWidth());
61 llvm::APSInt m = (i.extend(maxBits) % prime().extend(maxBits)).trunc(bitWidth());
62 if (m.isNegative()) {
63 return prime() + m;
64 }
65 return m;
66}
67
68llvm::APSInt Field::reduce(unsigned i) const {
69 auto ap = llvm::APSInt(llvm::APInt(bitWidth(), i));
70 return reduce(ap);
71}
72
73/* UnreducedInterval */
74
76 if (safeGt(a, b)) {
77 return Interval::Empty(field);
78 }
79 if (safeGe(width(), field.prime())) {
80 return Interval::Entire(field);
81 }
82 auto lhs = field.reduce(a), rhs = field.reduce(b);
83 // lhs and rhs are now guaranteed to have the same bitwidth, so we can use
84 // built-in functions.
85 if ((rhs - lhs).isZero()) {
86 return Interval::Degenerate(field, lhs);
87 }
88
89 const auto &half = field.half();
90 if (lhs.ule(rhs)) {
91 if (lhs.ult(half) && rhs.ult(half)) {
92 return Interval::TypeA(field, lhs, rhs);
93 } else if (lhs.ult(half)) {
94 return Interval::TypeC(field, lhs, rhs);
95 } else {
96 return Interval::TypeB(field, lhs, rhs);
97 }
98 } else {
99 if (lhs.uge(half) && rhs.ult(half)) {
100 return Interval::TypeF(field, lhs, rhs);
101 } else {
102 return Interval::Entire(field);
103 }
104 }
105}
106
108 auto &lhs = *this;
109 return UnreducedInterval(safeMax(lhs.a, rhs.a), safeMin(lhs.b, rhs.b));
110}
111
113 auto &lhs = *this;
114 return UnreducedInterval(safeMin(lhs.a, rhs.a), safeMax(lhs.b, rhs.b));
115}
116
118 if (isEmpty() || rhs.isEmpty()) {
119 return *this;
120 }
121 auto one = llvm::APSInt(llvm::APInt(a.getBitWidth(), 1));
122 auto bound = expandingSub(rhs.b, one);
123 return UnreducedInterval(a, safeMin(b, bound));
124}
125
127 if (isEmpty() || rhs.isEmpty()) {
128 return *this;
129 }
130 return UnreducedInterval(a, safeMin(b, rhs.b));
131}
132
134 if (isEmpty() || rhs.isEmpty()) {
135 return *this;
136 }
137 auto one = llvm::APSInt(llvm::APInt(a.getBitWidth(), 1));
138 auto bound = expandingAdd(rhs.a, one);
139 return UnreducedInterval(safeMax(a, bound), b);
140}
141
143 if (isEmpty() || rhs.isEmpty()) {
144 return *this;
145 }
146 return UnreducedInterval(safeMax(a, rhs.a), b);
147}
148
150
152 llvm::APSInt low = expandingAdd(lhs.a, rhs.a), high = expandingAdd(lhs.b, rhs.b);
153 return UnreducedInterval(low, high);
154}
155
157 return lhs + (-rhs);
158}
159
161 auto v1 = expandingMul(lhs.a, rhs.a);
162 auto v2 = expandingMul(lhs.a, rhs.b);
163 auto v3 = expandingMul(lhs.b, rhs.a);
164 auto v4 = expandingMul(lhs.b, rhs.b);
165
166 auto minVal = safeMin({v1, v2, v3, v4});
167 auto maxVal = safeMax({v1, v2, v3, v4});
168
169 return UnreducedInterval(minVal, maxVal);
170}
171
173 return isNotEmpty() && rhs.isNotEmpty() && safeGe(b, rhs.a) && safeLe(a, rhs.b);
174}
175
176std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
177 if (safeLt(lhs.a, rhs.a) || (safeEq(lhs.a, rhs.a) && safeLt(lhs.b, rhs.b))) {
178 return std::strong_ordering::less;
179 }
180 if (safeGt(lhs.a, rhs.a) || (safeEq(lhs.a, rhs.a) && safeGt(lhs.b, rhs.b))) {
181 return std::strong_ordering::greater;
182 }
183 return std::strong_ordering::equal;
184}
185
186llvm::APSInt UnreducedInterval::width() const {
187 llvm::APSInt w;
188 if (safeGt(a, b)) {
189 // This would be reduced to an empty Interval, so the width is just zero.
190 w = llvm::APSInt::getUnsigned(0);
191 } else {
193 w = expandingSub(b, a)++;
194 }
195 ensure(safeGe(w, llvm::APSInt::getUnsigned(0)), "cannot have negative width");
196 return w;
197}
198
199bool UnreducedInterval::isEmpty() const { return safeEq(width(), llvm::APSInt::getUnsigned(0)); }
200
201/* Interval */
202
204 if (isEmpty()) {
205 return UnreducedInterval(field.get().zero(), field.get().zero());
206 }
207 if (isEntire()) {
208 return UnreducedInterval(field.get().zero(), field.get().maxVal());
209 }
210 return UnreducedInterval(a, b);
211}
212
214 if (is<Type::TypeF>()) {
215 return UnreducedInterval(field.get().prime() - a, b);
216 }
217 return toUnreduced();
218}
219
221 ensure(is<Type::TypeA, Type::TypeB, Type::TypeC>(), "unsupported range type");
222 return UnreducedInterval(a - field.get().prime(), b - field.get().prime());
223}
224
226 auto &lhs = *this;
227 ensure(
228 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
229 );
230 const Field &f = lhs.getField();
231
232 // Trivial cases
233 if (lhs.isEntire() || rhs.isEntire()) {
234 return Interval::Entire(f);
235 }
236 if (lhs.isEmpty()) {
237 return rhs;
238 }
239 if (rhs.isEmpty()) {
240 return lhs;
241 }
242 if (lhs.isDegenerate() || rhs.isDegenerate()) {
243 return lhs.toUnreduced().doUnion(rhs.toUnreduced()).reduce(f);
244 }
245
246 // More complex cases
247 if (areOneOf<
250 return Interval(rhs.ty, f, std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
251 }
253 auto lhsUnred = lhs.firstUnreduced();
254 auto opt1 = rhs.firstUnreduced().doUnion(lhsUnred);
255 auto opt2 = rhs.secondUnreduced().doUnion(lhsUnred);
256 if (opt1.width() <= opt2.width()) {
257 return opt1.reduce(f);
258 }
259 return opt2.reduce(f);
260 }
262 return lhs.firstUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
263 }
265 return lhs.secondUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
266 }
268 return Interval::Entire(f);
269 }
270 if (areOneOf<
273 lhs, rhs
274 )) {
275 return rhs.join(lhs);
276 }
277 llvm::report_fatal_error("unhandled join case");
278 return Interval::Entire(f);
279}
280
282 auto &lhs = *this;
283 ensure(
284 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
285 );
286
287 // Trivial cases
288 if (lhs.isEmpty() || rhs.isEmpty()) {
289 return Interval::Empty(field.get());
290 }
291 if (lhs.isEntire()) {
292 return rhs;
293 }
294 if (rhs.isEntire()) {
295 return lhs;
296 }
297 if (lhs.isDegenerate() || rhs.isDegenerate()) {
298 return lhs.toUnreduced().intersect(rhs.toUnreduced()).reduce(field.get());
299 }
300
301 // More complex cases
302 if (areOneOf<
305 auto maxA = std::max(lhs.a, rhs.a);
306 auto minB = std::min(lhs.b, rhs.b);
307 if (maxA <= minB) {
308 return Interval(lhs.ty, field.get(), maxA, minB);
309 } else {
310 return Interval::Empty(field.get());
311 }
312 }
314 return Interval::Empty(field.get());
315 }
317 return lhs.firstUnreduced().intersect(rhs.firstUnreduced()).reduce(field.get());
318 }
320 return lhs.secondUnreduced().intersect(rhs.firstUnreduced()).reduce(field.get());
321 }
323 auto rhsUnred = rhs.firstUnreduced();
324 auto opt1 = lhs.firstUnreduced().intersect(rhsUnred).reduce(field.get());
325 auto opt2 = lhs.secondUnreduced().intersect(rhsUnred).reduce(field.get());
326 ensure(!opt1.isEntire() && !opt2.isEntire(), "impossible intersection");
327 if (opt1.isEmpty()) {
328 return opt2;
329 }
330 if (opt2.isEmpty()) {
331 return opt1;
332 }
333 return opt1.join(opt2);
334 }
335 if (areOneOf<
338 lhs, rhs
339 )) {
340 return rhs.intersect(lhs);
341 }
342 return Interval::Empty(field.get());
343}
344
346 // intersect checks that we're in the same field
348 if (intersection.isEmpty()) {
349 // There's nothing to remove, so just return this
350 return *this;
351 }
352
353 const Field &f = field.get();
354
355 // Trivial cases with a non-empty intersection
356 if (isDegenerate() || other.isEntire()) {
357 return Interval::Empty(f);
358 }
359 if (isEntire()) {
360 // Since we don't support punching arbitrary holes in ranges, we only reduce
361 // entire ranges if other is [0, b] or [a, prime - 1]
362 if (other.a == f.zero()) {
363 return UnreducedInterval(other.b + f.one(), f.maxVal()).reduce(f);
364 }
365 if (other.b == f.maxVal()) {
366 return UnreducedInterval(f.zero(), other.a - f.one()).reduce(f);
367 }
368
369 return *this;
370 }
371
372 // Non-trivial cases
373 // - Internal+internal or external+external cases
376 areOneOf<{Type::TypeF, Type::TypeF}>(*this, intersection)) {
377 // The intersection needs to be at the end of the interval, otherwise we would
378 // split the interval in two, and we aren't set up to support multiple intervals
379 // per value.
380 if (a != intersection.a && b != intersection.b) {
381 return *this;
382 }
383 // Otherwise, remove the intersection and reduce
384 if (a == intersection.a) {
385 return UnreducedInterval(intersection.b + f.one(), b).reduce(f);
386 }
387 // else b == intersection.b
388 return UnreducedInterval(a, intersection.a - f.one()).reduce(f);
389 }
390 // - Mixed internal/external cases. We flip the comparison
391 if (isTypeF()) {
392 if (a != intersection.b && b != intersection.a) {
393 return *this;
394 }
395 // Otherwise, remove the intersection and reduce
396 if (a == intersection.b) {
397 return UnreducedInterval(intersection.a + f.one(), b).reduce(f);
398 }
399 // else b == intersection.a
400 return UnreducedInterval(a, intersection.b - f.one()).reduce(f);
401 }
402
403 // In cases we don't know how to handle, we over-approximate and return
404 // the original interval.
405 return *this;
406}
407
408Interval Interval::operator-() const { return (-firstUnreduced()).reduce(field.get()); }
409
411 ensure(lhs.field.get() == rhs.field.get(), "cannot add intervals in different fields");
412 return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(lhs.field.get());
413}
414
415Interval operator-(const Interval &lhs, const Interval &rhs) { return lhs + (-rhs); }
416
418 ensure(lhs.field.get() == rhs.field.get(), "cannot multiply intervals in different fields");
419 const auto &field = lhs.field.get();
420 auto zeroInterval = Interval::Degenerate(field, field.zero());
421 if (lhs == zeroInterval || rhs == zeroInterval) {
422 return zeroInterval;
423 }
424 if (lhs.isEmpty() || rhs.isEmpty()) {
425 return Interval::Empty(field);
426 }
427 if (lhs.isEntire() || rhs.isEntire()) {
428 return Interval::Entire(field);
429 }
430
432 return (lhs.secondUnreduced() * rhs.secondUnreduced()).reduce(field);
433 }
434 return (lhs.firstUnreduced() * rhs.firstUnreduced()).reduce(field);
435}
436
437FailureOr<Interval> operator/(const Interval &lhs, const Interval &rhs) {
438 ensure(lhs.getField() == rhs.getField(), "cannot divide intervals in different fields");
439 const auto &field = rhs.getField();
440 if (rhs.width() > field.one()) {
441 return Interval::Entire(field);
442 }
443 if (rhs.a.isZero()) {
444 return failure();
445 }
446 return success(UnreducedInterval(lhs.a / rhs.a, lhs.b / rhs.a).reduce(field));
447}
448
450 ensure(
451 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
452 );
453 const auto &field = rhs.getField();
454 return UnreducedInterval(field.zero(), rhs.b).reduce(field);
455}
456
457void Interval::print(mlir::raw_ostream &os) const {
458 os << TypeName(ty);
459 if (is<Type::Degenerate>()) {
460 os << '(' << a << ')';
461 } else if (!is<Type::Entire, Type::Empty>()) {
462 os << ":[ " << a << ", " << b << " ]";
463 }
464}
465
466/* ExpressionValue */
467
469 if (expr == nullptr && rhs.expr == nullptr) {
470 return i == rhs.i;
471 }
472 if (expr == nullptr || rhs.expr == nullptr) {
473 return false;
474 }
475 return i == rhs.i && *expr == *rhs.expr;
476}
477
479intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
480 Interval res = lhs.i.intersect(rhs.i);
481 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
482 return ExpressionValue(exprEq, res);
483}
484
486add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
487 ExpressionValue res;
488 res.i = lhs.i + rhs.i;
489 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
490 return res;
491}
492
494sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
495 ExpressionValue res;
496 res.i = lhs.i - rhs.i;
497 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
498 return res;
499}
500
502mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
503 ExpressionValue res;
504 res.i = lhs.i * rhs.i;
505 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
506 return res;
507}
508
510div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs,
511 const ExpressionValue &rhs) {
512 ExpressionValue res;
513 auto divRes = lhs.i / rhs.i;
514 if (failed(divRes)) {
515 op->emitWarning(
516 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
517 " Range of division result will be treated as unbounded."
518 );
519 res.i = Interval::Entire(lhs.getField());
520 } else {
521 res.i = *divRes;
522 }
523 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
524 return res;
525}
526
528mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
529 ExpressionValue res;
530 res.i = lhs.i % rhs.i;
531 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
532 return res;
533}
534
536cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs) {
537 ExpressionValue res;
538 res.i = Interval::Boolean(lhs.getField());
539 switch (op.getPredicate()) {
540 case FeltCmpPredicate::EQ:
541 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
542 res.i = lhs.i.intersect(rhs.i);
543 break;
544 case FeltCmpPredicate::NE:
545 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
546 break;
547 case FeltCmpPredicate::LT:
548 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
549 break;
550 case FeltCmpPredicate::LE:
551 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
552 break;
553 case FeltCmpPredicate::GT:
554 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
555 break;
556 case FeltCmpPredicate::GE:
557 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
558 break;
559 }
560 return res;
561}
562
564 llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs
565) {
566 ExpressionValue res;
567 res.i = Interval::Entire(lhs.getField());
568 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
569 .Case<AndFeltOp>([&](AndFeltOp _) { return solver->mkBVAnd(lhs.expr, rhs.expr); })
570 .Case<OrFeltOp>([&](OrFeltOp _) { return solver->mkBVOr(lhs.expr, rhs.expr); })
571 .Case<XorFeltOp>([&](XorFeltOp _) { return solver->mkBVXor(lhs.expr, rhs.expr); })
572 .Case<ShlFeltOp>([&](ShlFeltOp _) { return solver->mkBVShl(lhs.expr, rhs.expr); })
573 .Case<ShrFeltOp>([&](ShrFeltOp _) {
574 return solver->mkBVLshr(lhs.expr, rhs.expr);
575 }).Default([&](Operation *unsupported) {
576 llvm::report_fatal_error(
577 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
578 );
579 return nullptr;
580 });
581
582 return res;
583}
584
585ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val) {
586 ExpressionValue res;
587 res.i = -val.i;
588 res.expr = solver->mkBVNeg(val.expr);
589 return res;
590}
591
592ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val) {
593 ExpressionValue res;
594 const auto &f = val.getField();
595 if (val.i.isDegenerate()) {
596 if (val.i == Interval::Degenerate(f, f.zero())) {
597 res.i = Interval::Degenerate(f, f.one());
598 } else {
599 res.i = Interval::Degenerate(f, f.zero());
600 }
601 }
602 res.i = Interval::Boolean(f);
603 res.expr = solver->mkBVNot(val.expr);
604 return res;
605}
606
608fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val) {
609 const Field &field = val.getField();
610 ExpressionValue res;
611 res.i = Interval::Entire(field);
612 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
613 .Case<InvFeltOp>([&](InvFeltOp _) {
614 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
615 // To create this expression, we create a new symbol for Y and add the
616 // XY % prime = 1 constraint to the solver.
617 std::string symName = buildStringViaInsertionOp(*op);
618 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
619 llvm::SMTExprRef one = solver->mkBitvector(field.one(), field.bitWidth());
620 llvm::SMTExprRef prime = solver->mkBitvector(field.prime(), field.bitWidth());
621 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
622 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
623 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
624 solver->addConstraint(constraint);
625 return invSym;
626 }).Default([&](Operation *unsupported) {
627 llvm::report_fatal_error(
628 "no fallback provided for " + mlir::Twine(op->getName().getStringRef())
629 );
630 return nullptr;
631 });
632
633 return res;
634}
635
636void ExpressionValue::print(mlir::raw_ostream &os) const {
637 if (expr) {
638 expr->print(os);
639 } else {
640 os << "<null expression>";
641 }
642
643 os << " ( interval: " << i << " )";
644}
645
646/* IntervalAnalysisLattice */
647
648ChangeResult IntervalAnalysisLattice::join(const AbstractDenseLattice &other) {
649 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
650 if (!rhs) {
651 llvm::report_fatal_error("invalid join lattice type");
652 }
653 ChangeResult res = ChangeResult::NoChange;
654 for (auto &[k, v] : rhs->valMap) {
655 auto it = valMap.find(k);
656 if (it == valMap.end() || it->second != v) {
657 valMap[k] = v;
658 res |= ChangeResult::Change;
659 }
660 }
661 for (auto &v : rhs->constraints) {
662 if (!constraints.contains(v)) {
663 constraints.insert(v);
664 res |= ChangeResult::Change;
665 }
666 }
667 for (auto &[e, i] : rhs->intervals) {
668 auto it = intervals.find(e);
669 if (it == intervals.end() || it->second != i) {
670 intervals[e] = i;
671 res |= ChangeResult::Change;
672 }
673 }
674 return res;
675}
676
677void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
678 os << "IntervalAnalysisLattice { ";
679 for (auto &[ref, val] : valMap) {
680 os << "\n (valMap) " << ref << " := " << val;
681 }
682 for (auto &[expr, interval] : intervals) {
683 os << "\n (intervals) ";
684 expr->print(os);
685 os << " in " << interval;
686 }
687 if (!valMap.empty()) {
688 os << '\n';
689 }
690 os << '}';
691}
692
693FailureOr<IntervalAnalysisLattice::LatticeValue> IntervalAnalysisLattice::getValue(Value v) const {
694 auto it = valMap.find(v);
695 if (it == valMap.end()) {
696 return failure();
697 }
698 return it->second;
699}
700
702 LatticeValue val(e);
703 if (valMap[v] == val) {
704 return ChangeResult::NoChange;
705 }
706 valMap[v] = val;
707 intervals[e.getExpr()] = e.getInterval();
708 return ChangeResult::Change;
709}
710
712 if (!constraints.contains(e)) {
713 constraints.insert(e);
714 return ChangeResult::Change;
715 }
716 return ChangeResult::NoChange;
717}
718
719FailureOr<Interval> IntervalAnalysisLattice::findInterval(llvm::SMTExprRef expr) const {
720 auto it = intervals.find(expr);
721 if (it != intervals.end()) {
722 return it->second;
723 }
724 return failure();
725}
726
727/* IntervalDataFlowAnalysis */
728
733 CallOpInterface call, dataflow::CallControlFlowAction action,
735) {
739 if (action == dataflow::CallControlFlowAction::EnterCallee) {
740 // We skip updating the incoming lattice for function calls,
741 // as values are relative to the containing function/struct, so we don't need to pollute
742 // the callee with the callers values.
743 setToEntryState(after);
744 }
748 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
749 // Get the argument values of the lattice by getting the state as it would
750 // have been for the callsite.
751 dataflow::AbstractDenseLattice *beforeCall = nullptr;
752 if (auto *prev = call->getPrevNode()) {
753 beforeCall = getLattice(prev);
754 } else {
755 beforeCall = getLattice(call->getBlock());
756 }
757 ensure(beforeCall, "could not get prior lattice");
758
759 // The lattice at the return is the lattice before the call
760 propagateIfChanged(after, after->join(*beforeCall));
761 }
766 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
767 // For external calls, we propagate what information we already have from
768 // before the call to after the call, since the external call won't invalidate
769 // any of that information. It also, conservatively, makes no assumptions about
770 // external calls and their computation, so CDG edges will not be computed over
771 // input arguments to external functions.
772 join(after, before);
773 }
774}
775
777 Operation *op, const Lattice &before, Lattice *after
778) {
779 ChangeResult changed = after->join(before);
780
781 llvm::SmallVector<LatticeValue> operandVals;
782
783 auto constrainRefLattice = dataflowSolver.lookupState<ConstrainRefLattice>(op);
784 ensure(constrainRefLattice, "failed to get lattice");
785
786 for (auto &operand : op->getOpOperands()) {
787 auto val = operand.get();
788 // First, lookup the operand value in the before state.
789 auto priorState = before.getValue(val);
790 if (succeeded(priorState)) {
791 operandVals.push_back(*priorState);
792 continue;
793 }
794
795 // Else, look up the stored value by constrain ref.
796 // We only care about scalar type values, so we ignore composite types, which
797 // are currently limited to non-Signal structs and arrays.
798 Type valTy = val.getType();
799 if (mlir::isa<ArrayType, StructType>(valTy) && !isSignalType(valTy)) {
800 operandVals.push_back(LatticeValue());
801 continue;
802 }
803
804 ConstrainRefLatticeValue refSet = constrainRefLattice->getOrDefault(val);
805 ensure(refSet.isScalar(), "should have ruled out array values already");
806
807 if (refSet.getScalarValue().empty()) {
808 // If we can't compute the reference, then there must be some unsupported
809 // op the reference analysis cannot handle. We emit a warning and return
810 // early, since there's no meaningful computation we can do for this op.
811 std::string warning;
812 debug::Appender(warning
813 ) << "state of "
814 << val << " is empty; defining operation is unsupported by constrain ref analysis";
815 op->emitWarning(warning);
816 propagateIfChanged(after, changed);
817 return;
818 } else if (!refSet.isSingleValue()) {
819 std::string warning;
820 debug::Appender(warning) << "operand " << val << " is not a single value " << refSet
821 << ", overapproximating";
822 op->emitWarning(warning);
823 // Here, we will override the prior lattice value with a new symbol, representing
824 // "any" value, then use that value for the operands.
825 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
826 changed |= after->setValue(val, anyVal);
827 operandVals.emplace_back(anyVal);
828 } else {
829 auto ref = refSet.getSingleValue();
830 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
831 changed |= after->setValue(val, exprVal);
832 operandVals.emplace_back(exprVal);
833 }
834 }
835
836 // Now, the way we update is dependent on the type of the operation.
837 if (!isConsideredOp(op)) {
838 op->emitWarning("unconsidered operation type, analysis may be incomplete");
839 }
840
841 if (isConstOp(op)) {
842 auto constVal = getConst(op);
843 auto expr = createConstBitvectorExpr(constVal);
844 ExpressionValue latticeVal(field.get(), expr, constVal);
845 changed |= after->setValue(op->getResult(0), latticeVal);
846 } else if (isArithmeticOp(op)) {
847 ensure(operandVals.size() <= 2, "arithmetic op with the wrong number of operands");
848 ExpressionValue result;
849 if (operandVals.size() == 2) {
850 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
851 } else {
852 result = performUnaryArithmetic(op, operandVals[0]);
853 }
854
855 changed |= after->setValue(op->getResult(0), result);
856 } else if (EmitEqualityOp emitEq = mlir::dyn_cast<EmitEqualityOp>(op)) {
857 ensure(operandVals.size() == 2, "constraint op with the wrong number of operands");
858 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
859 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
860 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
861
862 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
863 // Update the LHS and RHS to the same value, but restricted intervals
864 // based on the constraints
865 changed |= applyInterval(emitEq, after, lhsVal, constraint.getInterval());
866 changed |= applyInterval(emitEq, after, rhsVal, constraint.getInterval());
867 changed |= after->addSolverConstraint(constraint);
868 } else if (AssertOp assertOp = mlir::dyn_cast<AssertOp>(op)) {
869 ensure(operandVals.size() == 1, "assert op with the wrong number of operands");
870 // assert enforces that the operand is true. So we apply an interval of [1, 1]
871 // to the operand.
872 changed |= applyInterval(
873 assertOp, after, assertOp.getCondition(),
874 Interval::Degenerate(field.get(), field.get().one())
875 );
876 // Also add the solver constraint that the expression must be true.
877 auto assertExpr = operandVals[0].getScalarValue();
878 changed |= after->addSolverConstraint(assertExpr);
879 } else if (auto readf = mlir::dyn_cast<FieldReadOp>(op);
880 readf && isSignalType(readf.getComponent().getType())) {
881 // The reg value read from the signal type is equal to the value of the Signal
882 // struct overall.
883 changed |= after->setValue(readf.getVal(), operandVals[0].getScalarValue());
884 } else if (!isReadOp(op) /* We do not need to explicitly handle read ops
885 since they are resolved at the operand value step where constrain refs are
886 queries (with the exception of the Signal struct, see above). */
887 && !isReturnOp(op) /* We do not currently handle return ops as the analysis
888 is currently limited to constrain functions, which return no value. */
889 && !isDefinitionOp(op) /* The analysis ignores definition ops. */
890 &&
891 !mlir::isa<CreateStructOp>(op) /* We do not need to analyze the creation of structs. */
892 ) {
893 op->emitWarning("unhandled operation, analysis may be incomplete");
894 }
895
896 propagateIfChanged(after, changed);
897}
898
900 auto it = refSymbols.find(r);
901 if (it != refSymbols.end()) {
902 return it->second;
903 }
904 auto sym = createFeltSymbol(r);
905 refSymbols[r] = sym;
906 return sym;
907}
908
909llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const ConstrainRef &r) const {
910 return createFeltSymbol(buildStringViaPrint(r).c_str());
911}
912
913llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v) const {
914 return createFeltSymbol(buildStringViaPrint(v).c_str());
915}
916
917llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) const {
918 return field.get().createSymbol(smtSolver, name);
919}
920
921llvm::APSInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
922 ensure(isConstOp(op), "op is not a const op");
923
924 llvm::APInt fieldConst =
925 TypeSwitch<Operation *, llvm::APInt>(op)
926 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
927 llvm::APSInt constOpVal(feltConst.getValueAttr().getValue());
928 return field.get().reduce(constOpVal);
929 })
930 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
931 return llvm::APInt(field.get().bitWidth(), indexConst.value());
932 })
933 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
934 return llvm::APInt(field.get().bitWidth(), intConst.value());
935 }).Default([&](Operation *illegalOp) {
936 std::string err;
937 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
938 llvm::report_fatal_error(Twine(err));
939 return llvm::APInt();
940 });
941 return llvm::APSInt(fieldConst);
942}
943
944ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
945 Operation *op, const LatticeValue &a, const LatticeValue &b
946) {
947 ensure(isArithmeticOp(op), "is not arithmetic op");
948
949 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
950 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
951 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
952
953 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
954 .Case<AddFeltOp>([&](AddFeltOp _) { return add(smtSolver, lhs, rhs); })
955 .Case<SubFeltOp>([&](SubFeltOp _) { return sub(smtSolver, lhs, rhs); })
956 .Case<MulFeltOp>([&](MulFeltOp _) { return mul(smtSolver, lhs, rhs); })
957 .Case<DivFeltOp>([&](DivFeltOp divOp) { return div(smtSolver, divOp, lhs, rhs); })
958 .Case<ModFeltOp>([&](ModFeltOp _) { return mod(smtSolver, lhs, rhs); })
959 .Case<CmpOp>([&](CmpOp cmpOp) {
960 return cmp(smtSolver, cmpOp, lhs, rhs);
961 }).Default([&](Operation *unsupported) {
962 unsupported->emitWarning(
963 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
964 );
965 return fallbackBinaryOp(smtSolver, unsupported, lhs, rhs);
966 });
967
968 ensure(res.getExpr(), "arithmetic produced null smt expr");
969 return res;
970}
971
972ExpressionValue
973IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
974 ensure(isArithmeticOp(op), "is not arithmetic op");
975
976 auto val = a.getScalarValue();
977 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
978
979 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
980 .Case<NegFeltOp>([&](NegFeltOp _) { return neg(smtSolver, val); })
981 .Case<NotFeltOp>([&](NotFeltOp _) {
982 return notOp(smtSolver, val);
983 }).Default([&](Operation *unsupported) {
984 unsupported->emitWarning(
985 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
986 );
987 return fallbackUnaryOp(smtSolver, unsupported, val);
988 });
989
990 ensure(res.getExpr(), "arithmetic produced null smt expr");
991 return res;
992}
993
994ChangeResult IntervalDataFlowAnalysis::applyInterval(
995 Operation *originalOp, Lattice *after, Value val, Interval newInterval
996) {
997 auto latValRes = after->getValue(val);
998 if (failed(latValRes)) {
999 // visitOperation didn't add val to the lattice, so there's nothing to do
1000 return ChangeResult::NoChange;
1001 }
1002 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
1003 ChangeResult res = after->setValue(val, newLatticeVal);
1004 // To allow the dataflow analysis to do its fixed-point iteration, we need to
1005 // add the new expression to val's lattice as well.
1006 Lattice *valLattice = nullptr;
1007 if (auto valOp = val.getDefiningOp()) {
1008 // Getting the lattice at valOp gives us the "after" lattice, but we want to
1009 // update the "before" lattice so that the inputs to visitOperation will be
1010 // changed.
1011 if (auto prev = valOp->getPrevNode()) {
1012 valLattice = getOrCreate<Lattice>(prev);
1013 } else {
1014 valLattice = getOrCreate<Lattice>(valOp->getBlock());
1015 }
1016 } else if (auto blockArg = mlir::dyn_cast<BlockArgument>(val)) {
1017 valLattice = getOrCreate<Lattice>(blockArg.getOwner());
1018 } else {
1019 valLattice = getOrCreate<Lattice>(val);
1020 }
1021 ensure(valLattice, "val should have a lattice");
1022 if (valLattice != after) {
1023 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
1024 }
1025
1026 // Now we descend into val's operands, if it has any.
1027 Operation *definingOp = val.getDefiningOp();
1028 if (!definingOp) {
1029 return res;
1030 }
1031
1032 const Field &f = field.get();
1033
1034 // This is a rules-based operation. If we have a rule for a given operation,
1035 // then we can make some kind of update, otherwise we leave the intervals
1036 // as is.
1037 // - First we'll define all the rules so the type switch can be less messy
1038
1039 // cmp.<pred> restricts each side of the comparison if the result is known.
1040 auto cmpCase = [&](CmpOp cmpOp) {
1041 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
1042 // either "true" (1) or "false" (0)
1043 Interval maxInterval = Interval::Boolean(f);
1044 ensure(
1045 newInterval.intersect(maxInterval).isNotEmpty(),
1046 "new interval for CmpOp outside of allowed boolean range or is empty"
1047 );
1048 if (!newInterval.isDegenerate()) {
1049 // The comparison result is unknown, so we can't update the operand ranges
1050 return ChangeResult::NoChange;
1051 }
1052
1053 bool cmpTrue = newInterval.rhs() == f.one();
1054
1055 Value lhs = cmpOp->getOperand(0), rhs = cmpOp->getOperand(1);
1056 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
1057 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
1058 return ChangeResult::NoChange;
1059 }
1060 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
1061 rhsExpr = rhsLatValRes->getScalarValue();
1062
1063 Interval newLhsInterval, newRhsInterval;
1064 const Interval &lhsInterval = lhsExpr.getInterval();
1065 const Interval &rhsInterval = rhsExpr.getInterval();
1066
1067 FeltCmpPredicate pred = cmpOp.getPredicate();
1068 // predicate cases
1069 auto eqCase = [&]() {
1070 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1071 (pred == FeltCmpPredicate::NE && !cmpTrue);
1072 };
1073 auto neCase = [&]() {
1074 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1075 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1076 };
1077 auto ltCase = [&]() {
1078 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1079 (pred == FeltCmpPredicate::GE && !cmpTrue);
1080 };
1081 auto leCase = [&]() {
1082 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1083 (pred == FeltCmpPredicate::GT && !cmpTrue);
1084 };
1085 auto gtCase = [&]() {
1086 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1087 (pred == FeltCmpPredicate::LE && !cmpTrue);
1088 };
1089 auto geCase = [&]() {
1090 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1091 (pred == FeltCmpPredicate::LT && !cmpTrue);
1092 };
1093
1094 // new intervals based on case
1095 if (eqCase()) {
1096 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1097 } else if (neCase()) {
1098 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1099 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
1100 // an empty value range.
1101 newLhsInterval = newRhsInterval = Interval::Empty(f);
1102 } else {
1103 // Leave unchanged
1104 newLhsInterval = lhsInterval;
1105 newRhsInterval = rhsInterval;
1106 }
1107 } else if (ltCase()) {
1108 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1109 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1110 } else if (leCase()) {
1111 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1112 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1113 } else if (gtCase()) {
1114 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1115 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1116 } else if (geCase()) {
1117 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1118 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1119 } else {
1120 cmpOp->emitWarning("unhandled cmp predicate");
1121 return ChangeResult::NoChange;
1122 }
1123
1124 // Now we recurse to each operand
1125 return applyInterval(originalOp, after, lhs, newLhsInterval) |
1126 applyInterval(originalOp, after, rhs, newRhsInterval);
1127 };
1128
1129 // If the result of a multiplication is non-zero, then both operands must be
1130 // non-zero.
1131 auto mulCase = [&](MulFeltOp mulOp) {
1132 auto zeroInt = Interval::Degenerate(f, f.zero());
1133 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1134 // The multiplication may be zero, so we can't reduce the operands to be non-zero
1135 return ChangeResult::NoChange;
1136 }
1137
1138 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
1139 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
1140 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
1141 return ChangeResult::NoChange;
1142 }
1143 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
1144 rhsExpr = rhsLatValRes->getScalarValue();
1145 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1146 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1147 return applyInterval(originalOp, after, lhs, newLhsInterval) |
1148 applyInterval(originalOp, after, rhs, newRhsInterval);
1149 };
1150
1151 // We have a special case for the Signal struct: if this value is created
1152 // from reading a Signal struct's reg field, we also apply the interval to
1153 // the struct itself.
1154 auto readfCase = [&](FieldReadOp readfOp) {
1155 Value comp = readfOp.getComponent();
1156 if (isSignalType(comp.getType())) {
1157 return applyInterval(originalOp, after, comp, newInterval);
1158 }
1159 return ChangeResult::NoChange;
1160 };
1161
1162 // - Apply the rules given the op.
1163 // NOTE: disabling clang-format for this because it makes the last case statement
1164 // look ugly.
1165 // clang-format off
1166 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
1167 .Case<CmpOp>([&](CmpOp op) { return cmpCase(op); })
1168 .Case<MulFeltOp>([&](MulFeltOp op) { return mulCase(op); })
1169 .Case<FieldReadOp>([&](FieldReadOp op){ return readfCase(op); })
1170 .Default([&](Operation *_) { return ChangeResult::NoChange; });
1171 // clang-format on
1172
1173 return res;
1174}
1175
1176/* StructIntervals */
1177
1179 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx
1180) {
1181 // Get the lattice at the end of the constrain function.
1182 function::ReturnOp constrainEnd;
1183 structDef.getConstrainFuncOp().walk([&constrainEnd](function::ReturnOp r) mutable {
1184 constrainEnd = r;
1185 });
1186
1187 const IntervalAnalysisLattice *constrainLattice =
1188 solver.lookupState<IntervalAnalysisLattice>(constrainEnd);
1189
1190 constrainSolverConstraints = constrainLattice->getConstraints();
1191
1192 for (const auto &ref : ConstrainRef::getAllConstrainRefs(structDef)) {
1193 // We only want to compute intervals for field elements and not composite types,
1194 // with the exception of the Signal struct.
1195 if (!ref.isScalar() && !ref.isSignal()) {
1196 continue;
1197 }
1198 // We also don't want to show the interval for a Signal and its internal reg.
1199 if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1200 continue;
1201 }
1202 auto symbol = ctx.getSymbol(ref);
1203 auto constrainInterval = constrainLattice->findInterval(symbol);
1204 if (succeeded(constrainInterval)) {
1205 constrainFieldRanges[ref] = *constrainInterval;
1206 } else {
1207 constrainFieldRanges[ref] = Interval::Entire(ctx.field);
1208 }
1209 }
1210
1211 return success();
1212}
1213
1214void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints) const {
1215 os << "StructIntervals { ";
1216 if (constrainFieldRanges.empty()) {
1217 os << "}\n";
1218 return;
1219 }
1220
1221 for (auto &[ref, interval] : constrainFieldRanges) {
1222 os << "\n " << ref << " in " << interval;
1223 }
1224
1225 if (withConstraints) {
1226 os << "\n\n Solver Constraints { ";
1227 if (constrainSolverConstraints.empty()) {
1228 os << "}\n";
1229 } else {
1230 for (const auto &e : constrainSolverConstraints) {
1231 os << "\n ";
1232 e.getExpr()->print(os);
1233 }
1234 os << "\n }";
1235 }
1236 }
1237
1238 os << "\n}\n";
1239}
1240
1241} // namespace llzk
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
A value at a given point of the ConstrainRefLattice.
const ConstrainRef & getSingleValue() const
A lattice for use in dense analysis.
Defines a reference to a llzk object within a constrain function call.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
friend ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
const Interval & getInterval() const
friend ExpressionValue cmp(llvm::SMTSolverRef solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue div(llvm::SMTSolverRef solver, felt::DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
const Field & getField() const
friend ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
Information about the prime finite field used for the interval analysis.
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::APSInt zero() const
Returns 0 at the bitwidth of the field.
Field()=delete
llvm::APSInt half() const
Returns p / 2.
unsigned bitWidth() const
llvm::APSInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
static const Field & getField(const char *fieldName)
Get a Field from a given field name string.
llvm::APSInt prime() const
For the prime field p, returns p.
llvm::APSInt reduce(llvm::APSInt i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
Maps mlir::Values to LatticeValues.
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
Intervals over a finite field.
bool isEmpty() const
Interval intersect(const Interval &rhs) const
Intersect.
static std::string_view TypeName(Type t)
llvm::APSInt rhs() const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
bool isDegenerate() const
void print(mlir::raw_ostream &os) const
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
static Interval TypeC(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeF(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeB(const Field &f, llvm::APSInt a, llvm::APSInt b)
bool isTypeF() const
friend Interval operator*(const Interval &lhs, const Interval &rhs)
llvm::APSInt lhs() const
static Interval Empty(const Field &f)
static bool areOneOf(const Interval &a, const Interval &b)
static Interval Degenerate(const Field &f, llvm::APSInt val)
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
friend mlir::FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Returns failure if a division-by-zero is encountered.
static Interval TypeA(const Field &f, llvm::APSInt a, llvm::APSInt b)
friend Interval operator+(const Interval &lhs, const Interval &rhs)
bool isEntire() const
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
friend Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
void print(mlir::raw_ostream &os, bool withConstraints=false) const
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
UnreducedInterval operator-() const
friend UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
bool isEmpty() const
Returns true iff width() is zero.
UnreducedInterval(llvm::APSInt x, llvm::APSInt y)
llvm::APSInt width() const
Compute the width of this interval within a given field f.
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
friend std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
bool overlaps(const UnreducedInterval &rhs) const
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
friend UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:777
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
IntervalAnalysisLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingSub(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely subtract lhs and rhs, expanding the width of the result as necessary.
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingAdd(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely add lhs and rhs, expanding the width of the result as necessary.
llvm::APSInt safeMax(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:90
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
llvm::APSInt expandingMul(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely multiply lhs and rhs, expanding the width of the result as necessary.
bool safeGt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:74
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:66
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
bool safeLt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:58
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
bool safeLe(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:62
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
bool safeGe(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:78
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
llvm::APSInt safeMin(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:82
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
llvm::SMTExprRef getSymbol(const ConstrainRef &r)