LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
12#include "llzk/Util/Debug.h"
14
15#include <llvm/ADT/TypeSwitch.h>
16
17using namespace mlir;
18
19namespace llzk {
20
21using namespace array;
22using namespace boolean;
23using namespace component;
24using namespace constrain;
25using namespace felt;
26
27/* Field */
28
29Field::Field(std::string_view primeStr) : primeMod(llvm::APSInt(primeStr)) {
30 halfPrime = (primeMod + felt(1)) / felt(2);
31}
32
33const Field &Field::getField(const char *fieldName) {
34 static llvm::DenseMap<llvm::StringRef, Field> knownFields;
35 static std::once_flag fieldsInit;
36 std::call_once(fieldsInit, initKnownFields, knownFields);
37
38 if (auto it = knownFields.find(fieldName); it != knownFields.end()) {
39 return it->second;
40 }
41 llvm::report_fatal_error("field \"" + mlir::Twine(fieldName) + "\" is unsupported");
42}
43
44void Field::initKnownFields(llvm::DenseMap<llvm::StringRef, Field> &knownFields) {
45 // bn128/254, default for circom
46 knownFields.try_emplace(
47 "bn128",
48 Field("21888242871839275222246405745257275088696311157297823662689037894645226208583")
49 );
50 knownFields.try_emplace("bn254", knownFields.at("bn128"));
51 // 15 * 2^27 + 1, default for zirgen
52 knownFields.try_emplace("babybear", Field("2013265921"));
53 // 2^64 - 2^32 + 1, used for plonky2
54 knownFields.try_emplace("goldilocks", Field("18446744069414584321"));
55 // 2^31 - 1, used for Plonky3
56 knownFields.try_emplace("mersenne31", Field("2147483647"));
57}
58
59llvm::APSInt Field::reduce(llvm::APSInt i) const {
60 unsigned maxBits = std::max(i.getBitWidth(), bitWidth());
61 llvm::APSInt m = (i.extend(maxBits) % prime().extend(maxBits)).trunc(bitWidth());
62 if (m.isNegative()) {
63 return prime() + m;
64 }
65 return m;
66}
67
68llvm::APSInt Field::reduce(unsigned i) const {
69 auto ap = llvm::APSInt(llvm::APInt(bitWidth(), i));
70 return reduce(ap);
71}
72
73/* UnreducedInterval */
74
76 if (safeGt(a, b)) {
77 return Interval::Empty(field);
78 }
79 if (safeGe(width(), field.prime())) {
80 return Interval::Entire(field);
81 }
82 auto lhs = field.reduce(a), rhs = field.reduce(b);
83 // lhs and rhs are now guaranteed to have the same bitwidth, so we can use
84 // built-in functions.
85 if ((rhs - lhs).isZero()) {
86 return Interval::Degenerate(field, lhs);
87 }
88
89 const auto &half = field.half();
90 if (lhs.ule(rhs)) {
91 if (lhs.ult(half) && rhs.ult(half)) {
92 return Interval::TypeA(field, lhs, rhs);
93 } else if (lhs.ult(half)) {
94 return Interval::TypeC(field, lhs, rhs);
95 } else {
96 return Interval::TypeB(field, lhs, rhs);
97 }
98 } else {
99 if (lhs.uge(half) && rhs.ult(half)) {
100 return Interval::TypeF(field, lhs, rhs);
101 } else {
102 return Interval::Entire(field);
103 }
104 }
105}
106
108 auto &lhs = *this;
109 return UnreducedInterval(safeMax(lhs.a, rhs.a), safeMin(lhs.b, rhs.b));
110}
111
113 auto &lhs = *this;
114 return UnreducedInterval(safeMin(lhs.a, rhs.a), safeMax(lhs.b, rhs.b));
115}
116
118 if (isEmpty() || rhs.isEmpty()) {
119 return *this;
120 }
121 auto one = llvm::APSInt(llvm::APInt(a.getBitWidth(), 1));
122 auto bound = expandingSub(rhs.b, one);
123 return UnreducedInterval(a, safeMin(b, bound));
124}
125
127 if (isEmpty() || rhs.isEmpty()) {
128 return *this;
129 }
130 return UnreducedInterval(a, safeMin(b, rhs.b));
131}
132
134 if (isEmpty() || rhs.isEmpty()) {
135 return *this;
136 }
137 auto one = llvm::APSInt(llvm::APInt(a.getBitWidth(), 1));
138 auto bound = expandingAdd(rhs.a, one);
139 return UnreducedInterval(safeMax(a, bound), b);
140}
141
143 if (isEmpty() || rhs.isEmpty()) {
144 return *this;
145 }
146 return UnreducedInterval(safeMax(a, rhs.a), b);
147}
148
150
152 llvm::APSInt low = expandingAdd(lhs.a, rhs.a), high = expandingAdd(lhs.b, rhs.b);
153 return UnreducedInterval(low, high);
154}
155
157 return lhs + (-rhs);
158}
159
161 auto v1 = expandingMul(lhs.a, rhs.a);
162 auto v2 = expandingMul(lhs.a, rhs.b);
163 auto v3 = expandingMul(lhs.b, rhs.a);
164 auto v4 = expandingMul(lhs.b, rhs.b);
165
166 auto minVal = safeMin({v1, v2, v3, v4});
167 auto maxVal = safeMax({v1, v2, v3, v4});
168
169 return UnreducedInterval(minVal, maxVal);
170}
171
173 return isNotEmpty() && rhs.isNotEmpty() && safeGe(b, rhs.a) && safeLe(a, rhs.b);
174}
175
176std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
177 if (safeLt(lhs.a, rhs.a) || (safeEq(lhs.a, rhs.a) && safeLt(lhs.b, rhs.b))) {
178 return std::strong_ordering::less;
179 }
180 if (safeGt(lhs.a, rhs.a) || (safeEq(lhs.a, rhs.a) && safeGt(lhs.b, rhs.b))) {
181 return std::strong_ordering::greater;
182 }
183 return std::strong_ordering::equal;
184}
185
186llvm::APSInt UnreducedInterval::width() const {
187 llvm::APSInt w;
188 if (safeGt(a, b)) {
189 // This would be reduced to an empty Interval, so the width is just zero.
190 w = llvm::APSInt::getUnsigned(0);
191 } else {
193 w = expandingSub(b, a)++;
194 }
195 ensure(safeGe(w, llvm::APSInt::getUnsigned(0)), "cannot have negative width");
196 return w;
197}
198
199bool UnreducedInterval::isEmpty() const { return safeEq(width(), llvm::APSInt::getUnsigned(0)); }
200
201/* Interval */
202
204 if (isEmpty()) {
205 return UnreducedInterval(field.get().zero(), field.get().zero());
206 }
207 if (isEntire()) {
208 return UnreducedInterval(field.get().zero(), field.get().maxVal());
209 }
210 return UnreducedInterval(a, b);
211}
212
214 if (is<Type::TypeF>()) {
215 return UnreducedInterval(field.get().prime() - a, b);
216 }
217 return toUnreduced();
218}
219
221 ensure(is<Type::TypeA, Type::TypeB, Type::TypeC>(), "unsupported range type");
222 return UnreducedInterval(a - field.get().prime(), b - field.get().prime());
223}
224
226 auto &lhs = *this;
227 ensure(
228 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
229 );
230 const Field &f = lhs.getField();
231
232 // Trivial cases
233 if (lhs.isEntire() || rhs.isEntire()) {
234 return Interval::Entire(f);
235 }
236 if (lhs.isEmpty()) {
237 return rhs;
238 }
239 if (rhs.isEmpty()) {
240 return lhs;
241 }
242 if (lhs.isDegenerate() || rhs.isDegenerate()) {
243 return lhs.toUnreduced().doUnion(rhs.toUnreduced()).reduce(f);
244 }
245
246 // More complex cases
247 if (areOneOf<
250 return Interval(rhs.ty, f, std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
251 }
253 auto lhsUnred = lhs.firstUnreduced();
254 auto opt1 = rhs.firstUnreduced().doUnion(lhsUnred);
255 auto opt2 = rhs.secondUnreduced().doUnion(lhsUnred);
256 if (opt1.width() <= opt2.width()) {
257 return opt1.reduce(f);
258 }
259 return opt2.reduce(f);
260 }
262 return lhs.firstUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
263 }
265 return lhs.secondUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
266 }
268 return Interval::Entire(f);
269 }
270 if (areOneOf<
273 lhs, rhs
274 )) {
275 return rhs.join(lhs);
276 }
277 llvm::report_fatal_error("unhandled join case");
278 return Interval::Entire(f);
279}
280
282 auto &lhs = *this;
283 ensure(
284 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
285 );
286
287 // Trivial cases
288 if (lhs.isEmpty() || rhs.isEmpty()) {
289 return Interval::Empty(field.get());
290 }
291 if (lhs.isEntire()) {
292 return rhs;
293 }
294 if (rhs.isEntire()) {
295 return lhs;
296 }
297 if (lhs.isDegenerate() || rhs.isDegenerate()) {
298 return lhs.toUnreduced().intersect(rhs.toUnreduced()).reduce(field.get());
299 }
300
301 // More complex cases
302 if (areOneOf<
305 auto maxA = std::max(lhs.a, rhs.a);
306 auto minB = std::min(lhs.b, rhs.b);
307 if (maxA <= minB) {
308 return Interval(lhs.ty, field.get(), maxA, minB);
309 } else {
310 return Interval::Empty(field.get());
311 }
312 }
314 return Interval::Empty(field.get());
315 }
317 return lhs.firstUnreduced().intersect(rhs.firstUnreduced()).reduce(field.get());
318 }
320 return lhs.secondUnreduced().intersect(rhs.firstUnreduced()).reduce(field.get());
321 }
323 auto rhsUnred = rhs.firstUnreduced();
324 auto opt1 = lhs.firstUnreduced().intersect(rhsUnred).reduce(field.get());
325 auto opt2 = lhs.secondUnreduced().intersect(rhsUnred).reduce(field.get());
326 ensure(!opt1.isEntire() && !opt2.isEntire(), "impossible intersection");
327 if (opt1.isEmpty()) {
328 return opt2;
329 }
330 if (opt2.isEmpty()) {
331 return opt1;
332 }
333 return opt1.join(opt2);
334 }
335 if (areOneOf<
338 lhs, rhs
339 )) {
340 return rhs.intersect(lhs);
341 }
342 return Interval::Empty(field.get());
343}
344
346 // intersect checks that we're in the same field
348 if (intersection.isEmpty()) {
349 // There's nothing to remove, so just return this
350 return *this;
351 }
352
353 const Field &f = field.get();
354
355 // Trivial cases with a non-empty intersection
356 if (isDegenerate() || other.isEntire()) {
357 return Interval::Empty(f);
358 }
359 if (isEntire()) {
360 // Since we don't support punching arbitrary holes in ranges, we only reduce
361 // entire ranges if other is [0, b] or [a, prime - 1]
362 if (other.a == f.zero()) {
363 return UnreducedInterval(other.b + f.one(), f.maxVal()).reduce(f);
364 }
365 if (other.b == f.maxVal()) {
366 return UnreducedInterval(f.zero(), other.a - f.one()).reduce(f);
367 }
368
369 return *this;
370 }
371
372 // Non-trivial cases
373 // - Internal+internal or external+external cases
376 areOneOf<{Type::TypeF, Type::TypeF}>(*this, intersection)) {
377 // The intersection needs to be at the end of the interval, otherwise we would
378 // split the interval in two, and we aren't set up to support multiple intervals
379 // per value.
380 if (a != intersection.a && b != intersection.b) {
381 return *this;
382 }
383 // Otherwise, remove the intersection and reduce
384 if (a == intersection.a) {
385 return UnreducedInterval(intersection.b + f.one(), b).reduce(f);
386 }
387 // else b == intersection.b
388 return UnreducedInterval(a, intersection.a - f.one()).reduce(f);
389 }
390 // - Mixed internal/external cases. We flip the comparison
391 if (isTypeF()) {
392 if (a != intersection.b && b != intersection.a) {
393 return *this;
394 }
395 // Otherwise, remove the intersection and reduce
396 if (a == intersection.b) {
397 return UnreducedInterval(intersection.a + f.one(), b).reduce(f);
398 }
399 // else b == intersection.a
400 return UnreducedInterval(a, intersection.b - f.one()).reduce(f);
401 }
402
403 // In cases we don't know how to handle, we over-approximate and return
404 // the original interval.
405 return *this;
406}
407
408Interval Interval::operator-() const { return (-firstUnreduced()).reduce(field.get()); }
409
411 ensure(lhs.field.get() == rhs.field.get(), "cannot add intervals in different fields");
412 return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(lhs.field.get());
413}
414
415Interval operator-(const Interval &lhs, const Interval &rhs) { return lhs + (-rhs); }
416
418 ensure(lhs.field.get() == rhs.field.get(), "cannot multiply intervals in different fields");
419 const auto &field = lhs.field.get();
420 auto zeroInterval = Interval::Degenerate(field, field.zero());
421 if (lhs == zeroInterval || rhs == zeroInterval) {
422 return zeroInterval;
423 }
424 if (lhs.isEmpty() || rhs.isEmpty()) {
425 return Interval::Empty(field);
426 }
427 if (lhs.isEntire() || rhs.isEntire()) {
428 return Interval::Entire(field);
429 }
430
432 return (lhs.secondUnreduced() * rhs.secondUnreduced()).reduce(field);
433 }
434 return (lhs.firstUnreduced() * rhs.firstUnreduced()).reduce(field);
435}
436
437FailureOr<Interval> operator/(const Interval &lhs, const Interval &rhs) {
438 ensure(lhs.getField() == rhs.getField(), "cannot divide intervals in different fields");
439 const auto &field = rhs.getField();
440 if (rhs.width() > field.one()) {
441 return Interval::Entire(field);
442 }
443 if (rhs.a.isZero()) {
444 return failure();
445 }
446 return success(UnreducedInterval(lhs.a / rhs.a, lhs.b / rhs.a).reduce(field));
447}
448
450 ensure(
451 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
452 );
453 const auto &field = rhs.getField();
454 return UnreducedInterval(field.zero(), rhs.b).reduce(field);
455}
456
457void Interval::print(mlir::raw_ostream &os) const {
458 os << TypeName(ty);
459 if (is<Type::Degenerate>()) {
460 os << '(' << a << ')';
461 } else if (!is<Type::Entire, Type::Empty>()) {
462 os << ":[ " << a << ", " << b << " ]";
463 }
464}
465
466/* ExpressionValue */
467
469 if (expr == nullptr && rhs.expr == nullptr) {
470 return i == rhs.i;
471 }
472 if (expr == nullptr || rhs.expr == nullptr) {
473 return false;
474 }
475 return i == rhs.i && *expr == *rhs.expr;
476}
477
479intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
480 Interval res = lhs.i.intersect(rhs.i);
481 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
482 return ExpressionValue(exprEq, res);
483}
484
486add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
487 ExpressionValue res;
488 res.i = lhs.i + rhs.i;
489 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
490 return res;
491}
492
494sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
495 ExpressionValue res;
496 res.i = lhs.i - rhs.i;
497 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
498 return res;
499}
500
502mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
503 ExpressionValue res;
504 res.i = lhs.i * rhs.i;
505 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
506 return res;
507}
508
510div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs,
511 const ExpressionValue &rhs) {
512 ExpressionValue res;
513 auto divRes = lhs.i / rhs.i;
514 if (failed(divRes)) {
515 op->emitWarning(
516 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
517 " Range of division result will be treated as unbounded."
518 );
519 res.i = Interval::Entire(lhs.getField());
520 } else {
521 res.i = *divRes;
522 }
523 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
524 return res;
525}
526
528mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
529 ExpressionValue res;
530 res.i = lhs.i % rhs.i;
531 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
532 return res;
533}
534
536cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs) {
537 ExpressionValue res;
538 res.i = Interval::Boolean(lhs.getField());
539 switch (op.getPredicate()) {
540 case FeltCmpPredicate::EQ:
541 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
542 res.i = lhs.i.intersect(rhs.i);
543 break;
544 case FeltCmpPredicate::NE:
545 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
546 break;
547 case FeltCmpPredicate::LT:
548 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
549 break;
550 case FeltCmpPredicate::LE:
551 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
552 break;
553 case FeltCmpPredicate::GT:
554 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
555 break;
556 case FeltCmpPredicate::GE:
557 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
558 break;
559 }
560 return res;
561}
562
564 llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs
565) {
566 ExpressionValue res;
567 res.i = Interval::Entire(lhs.getField());
568 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
569 .Case<AndFeltOp>([&](AndFeltOp _) { return solver->mkBVAnd(lhs.expr, rhs.expr); })
570 .Case<OrFeltOp>([&](OrFeltOp _) { return solver->mkBVOr(lhs.expr, rhs.expr); })
571 .Case<XorFeltOp>([&](XorFeltOp _) { return solver->mkBVXor(lhs.expr, rhs.expr); })
572 .Case<ShlFeltOp>([&](ShlFeltOp _) { return solver->mkBVShl(lhs.expr, rhs.expr); })
573 .Case<ShrFeltOp>([&](ShrFeltOp _) {
574 return solver->mkBVLshr(lhs.expr, rhs.expr);
575 }).Default([&](Operation *unsupported) {
576 llvm::report_fatal_error(
577 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
578 );
579 return nullptr;
580 });
581
582 return res;
583}
584
585ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val) {
586 ExpressionValue res;
587 res.i = -val.i;
588 res.expr = solver->mkBVNeg(val.expr);
589 return res;
590}
591
592ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val) {
593 ExpressionValue res;
594 const auto &f = val.getField();
595 if (val.i.isDegenerate()) {
596 if (val.i == Interval::Degenerate(f, f.zero())) {
597 res.i = Interval::Degenerate(f, f.one());
598 } else {
599 res.i = Interval::Degenerate(f, f.zero());
600 }
601 }
602 res.i = Interval::Boolean(f);
603 res.expr = solver->mkBVNot(val.expr);
604 return res;
605}
606
608fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val) {
609 const Field &field = val.getField();
610 ExpressionValue res;
611 res.i = Interval::Entire(field);
612 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
613 .Case<InvFeltOp>([&](InvFeltOp _) {
614 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
615 // To create this expression, we create a new symbol for Y and add the
616 // XY % prime = 1 constraint to the solver.
617 std::string symName = buildStringViaInsertionOp(*op);
618 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
619 llvm::SMTExprRef one = solver->mkBitvector(field.one(), field.bitWidth());
620 llvm::SMTExprRef prime = solver->mkBitvector(field.prime(), field.bitWidth());
621 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
622 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
623 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
624 solver->addConstraint(constraint);
625 return invSym;
626 }).Default([&](Operation *unsupported) {
627 llvm::report_fatal_error(
628 "no fallback provided for " + mlir::Twine(op->getName().getStringRef())
629 );
630 return nullptr;
631 });
632
633 return res;
634}
635
636void ExpressionValue::print(mlir::raw_ostream &os) const {
637 if (expr) {
638 expr->print(os);
639 } else {
640 os << "<null expression>";
641 }
642
643 os << " ( interval: " << i << " )";
644}
645
646/* IntervalAnalysisLattice */
647
648ChangeResult IntervalAnalysisLattice::join(const AbstractDenseLattice &other) {
649 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
650 if (!rhs) {
651 llvm::report_fatal_error("invalid join lattice type");
652 }
653 ChangeResult res = ChangeResult::NoChange;
654 for (auto &[k, v] : rhs->valMap) {
655 auto it = valMap.find(k);
656 if (it == valMap.end() || it->second != v) {
657 valMap[k] = v;
658 res |= ChangeResult::Change;
659 }
660 }
661 for (auto &v : rhs->constraints) {
662 if (!constraints.contains(v)) {
663 constraints.insert(v);
664 res |= ChangeResult::Change;
665 }
666 }
667 for (auto &[e, i] : rhs->intervals) {
668 auto it = intervals.find(e);
669 if (it == intervals.end() || it->second != i) {
670 intervals[e] = i;
671 res |= ChangeResult::Change;
672 }
673 }
674 return res;
675}
676
677void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
678 os << "IntervalAnalysisLattice { ";
679 for (auto &[ref, val] : valMap) {
680 os << "\n (valMap) " << ref << " := " << val;
681 }
682 for (auto &[expr, interval] : intervals) {
683 os << "\n (intervals) ";
684 expr->print(os);
685 os << " in " << interval;
686 }
687 if (!valMap.empty()) {
688 os << '\n';
689 }
690 os << '}';
691}
692
693FailureOr<IntervalAnalysisLattice::LatticeValue> IntervalAnalysisLattice::getValue(Value v) const {
694 auto it = valMap.find(v);
695 if (it == valMap.end()) {
696 return failure();
697 }
698 return it->second;
699}
700
702 LatticeValue val(e);
703 if (valMap[v] == val) {
704 return ChangeResult::NoChange;
705 }
706 valMap[v] = val;
707 intervals[e.getExpr()] = e.getInterval();
708 return ChangeResult::Change;
709}
710
712 if (!constraints.contains(e)) {
713 constraints.insert(e);
714 return ChangeResult::Change;
715 }
716 return ChangeResult::NoChange;
717}
718
719FailureOr<Interval> IntervalAnalysisLattice::findInterval(llvm::SMTExprRef expr) const {
720 auto it = intervals.find(expr);
721 if (it != intervals.end()) {
722 return it->second;
723 }
724 return failure();
725}
726
727/* IntervalDataFlowAnalysis */
728
733 CallOpInterface call, dataflow::CallControlFlowAction action,
735) {
739 if (action == dataflow::CallControlFlowAction::EnterCallee) {
740 // We skip updating the incoming lattice for function calls,
741 // as values are relative to the containing function/struct, so we don't need to pollute
742 // the callee with the callers values.
743 setToEntryState(after);
744 }
748 else if (action == dataflow::CallControlFlowAction::ExitCallee) {
749 // Get the argument values of the lattice by getting the state as it would
750 // have been for the callsite.
751 dataflow::AbstractDenseLattice *beforeCall = nullptr;
752 if (auto *prev = call->getPrevNode()) {
753 beforeCall = getLattice(prev);
754 } else {
755 beforeCall = getLattice(call->getBlock());
756 }
757 ensure(beforeCall, "could not get prior lattice");
758
759 // The lattice at the return is the lattice before the call
760 propagateIfChanged(after, after->join(*beforeCall));
761 }
766 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
767 // For external calls, we propagate what information we already have from
768 // before the call to after the call, since the external call won't invalidate
769 // any of that information. It also, conservatively, makes no assumptions about
770 // external calls and their computation, so CDG edges will not be computed over
771 // input arguments to external functions.
772 join(after, before);
773 }
774}
775
777 Operation *op, const Lattice &before, Lattice *after
778) {
779 ChangeResult changed = after->join(before);
780
781 llvm::SmallVector<LatticeValue> operandVals;
782
783 auto constrainRefLattice = dataflowSolver.lookupState<ConstrainRefLattice>(op);
784 ensure(constrainRefLattice, "failed to get lattice");
785
786 for (OpOperand &operand : op->getOpOperands()) {
787 Value val = operand.get();
788 // First, lookup the operand value in the before state.
789 auto priorState = before.getValue(val);
790 if (succeeded(priorState)) {
791 operandVals.push_back(*priorState);
792 continue;
793 }
794
795 // Else, look up the stored value by constrain ref.
796 // We only care about scalar type values, so we ignore composite types, which
797 // are currently limited to non-Signal structs and arrays.
798 Type valTy = val.getType();
799 if (mlir::isa<ArrayType, StructType>(valTy) && !isSignalType(valTy)) {
800 operandVals.push_back(LatticeValue());
801 continue;
802 }
803
804 ConstrainRefLatticeValue refSet = constrainRefLattice->getOrDefault(val);
805 ensure(refSet.isScalar(), "should have ruled out array values already");
806
807 if (refSet.getScalarValue().empty()) {
808 // If we can't compute the reference, then there must be some unsupported
809 // op the reference analysis cannot handle. We emit a warning and return
810 // early, since there's no meaningful computation we can do for this op.
811 op->emitWarning() << "state of " << val
812 << " is empty; defining operation is unsupported by constrain ref analysis";
813 propagateIfChanged(after, changed);
814 return;
815 } else if (!refSet.isSingleValue()) {
816 std::string warning;
817 debug::Appender(warning) << "operand " << val << " is not a single value " << refSet
818 << ", overapproximating";
819 op->emitWarning(warning);
820 // Here, we will override the prior lattice value with a new symbol, representing
821 // "any" value, then use that value for the operands.
822 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
823 changed |= after->setValue(val, anyVal);
824 operandVals.emplace_back(anyVal);
825 } else {
826 auto ref = refSet.getSingleValue();
827 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
828 changed |= after->setValue(val, exprVal);
829 operandVals.emplace_back(exprVal);
830 }
831 }
832
833 // Now, the way we update is dependent on the type of the operation.
834 if (!isConsideredOp(op)) {
835 op->emitWarning("unconsidered operation type, analysis may be incomplete");
836 }
837
838 if (isConstOp(op)) {
839 auto constVal = getConst(op);
840 auto expr = createConstBitvectorExpr(constVal);
841 ExpressionValue latticeVal(field.get(), expr, constVal);
842 changed |= after->setValue(op->getResult(0), latticeVal);
843 } else if (isArithmeticOp(op)) {
844 ensure(operandVals.size() <= 2, "arithmetic op with the wrong number of operands");
845 ExpressionValue result;
846 if (operandVals.size() == 2) {
847 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
848 } else {
849 result = performUnaryArithmetic(op, operandVals[0]);
850 }
851
852 changed |= after->setValue(op->getResult(0), result);
853 } else if (EmitEqualityOp emitEq = mlir::dyn_cast<EmitEqualityOp>(op)) {
854 ensure(operandVals.size() == 2, "constraint op with the wrong number of operands");
855 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
856 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
857 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
858
859 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
860 // Update the LHS and RHS to the same value, but restricted intervals
861 // based on the constraints
862 changed |= applyInterval(emitEq, after, lhsVal, constraint.getInterval());
863 changed |= applyInterval(emitEq, after, rhsVal, constraint.getInterval());
864 changed |= after->addSolverConstraint(constraint);
865 } else if (AssertOp assertOp = mlir::dyn_cast<AssertOp>(op)) {
866 ensure(operandVals.size() == 1, "assert op with the wrong number of operands");
867 // assert enforces that the operand is true. So we apply an interval of [1, 1]
868 // to the operand.
869 changed |= applyInterval(
870 assertOp, after, assertOp.getCondition(),
871 Interval::Degenerate(field.get(), field.get().one())
872 );
873 // Also add the solver constraint that the expression must be true.
874 auto assertExpr = operandVals[0].getScalarValue();
875 changed |= after->addSolverConstraint(assertExpr);
876 } else if (auto readf = mlir::dyn_cast<FieldReadOp>(op);
877 readf && isSignalType(readf.getComponent().getType())) {
878 // The reg value read from the signal type is equal to the value of the Signal
879 // struct overall.
880 changed |= after->setValue(readf.getVal(), operandVals[0].getScalarValue());
881 } else if (
882 // We do not need to explicitly handle read ops since they are resolved at the operand value
883 // step where constrain refs are queries (with the exception of the Signal struct, see above).
884 !isReadOp(op)
885 // We do not currently handle return ops as the analysis is currently limited to constrain
886 // functions, which return no value.
887 && !isReturnOp(op)
888 // The analysis ignores definition ops.
889 && !isDefinitionOp(op)
890 // We do not need to analyze the creation of structs.
891 && !mlir::isa<CreateStructOp>(op)
892 ) {
893 op->emitWarning("unhandled operation, analysis may be incomplete");
894 }
895
896 propagateIfChanged(after, changed);
897}
898
900 auto it = refSymbols.find(r);
901 if (it != refSymbols.end()) {
902 return it->second;
903 }
904 auto sym = createFeltSymbol(r);
905 refSymbols[r] = sym;
906 return sym;
907}
908
909llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const ConstrainRef &r) const {
910 return createFeltSymbol(buildStringViaPrint(r).c_str());
911}
912
913llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v) const {
914 return createFeltSymbol(buildStringViaPrint(v).c_str());
915}
916
917llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) const {
918 return field.get().createSymbol(smtSolver, name);
919}
920
921llvm::APSInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
922 ensure(isConstOp(op), "op is not a const op");
923
924 llvm::APInt fieldConst =
925 TypeSwitch<Operation *, llvm::APInt>(op)
926 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
927 llvm::APSInt constOpVal(feltConst.getValueAttr().getValue());
928 return field.get().reduce(constOpVal);
929 })
930 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
931 return llvm::APInt(field.get().bitWidth(), indexConst.value());
932 })
933 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
934 return llvm::APInt(field.get().bitWidth(), intConst.value());
935 }).Default([](Operation *illegalOp) {
936 std::string err;
937 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
938 llvm::report_fatal_error(Twine(err));
939 return llvm::APInt();
940 });
941 return llvm::APSInt(fieldConst);
942}
943
944ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
945 Operation *op, const LatticeValue &a, const LatticeValue &b
946) {
947 ensure(isArithmeticOp(op), "is not arithmetic op");
948
949 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
950 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
951 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
952
953 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
954 .Case<AddFeltOp>([&](AddFeltOp _) { return add(smtSolver, lhs, rhs); })
955 .Case<SubFeltOp>([&](SubFeltOp _) { return sub(smtSolver, lhs, rhs); })
956 .Case<MulFeltOp>([&](MulFeltOp _) { return mul(smtSolver, lhs, rhs); })
957 .Case<DivFeltOp>([&](DivFeltOp divOp) { return div(smtSolver, divOp, lhs, rhs); })
958 .Case<ModFeltOp>([&](ModFeltOp _) { return mod(smtSolver, lhs, rhs); })
959 .Case<CmpOp>([&](CmpOp cmpOp) {
960 return cmp(smtSolver, cmpOp, lhs, rhs);
961 }).Default([&](Operation *unsupported) {
962 unsupported->emitWarning(
963 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
964 );
965 return fallbackBinaryOp(smtSolver, unsupported, lhs, rhs);
966 });
967
968 ensure(res.getExpr(), "arithmetic produced null smt expr");
969 return res;
970}
971
972ExpressionValue
973IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
974 ensure(isArithmeticOp(op), "is not arithmetic op");
975
976 auto val = a.getScalarValue();
977 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
978
979 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
980 .Case<NegFeltOp>([&](NegFeltOp _) { return neg(smtSolver, val); })
981 .Case<NotFeltOp>([&](NotFeltOp _) {
982 return notOp(smtSolver, val);
983 }).Default([&](Operation *unsupported) {
984 unsupported->emitWarning(
985 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
986 );
987 return fallbackUnaryOp(smtSolver, unsupported, val);
988 });
989
990 ensure(res.getExpr(), "arithmetic produced null smt expr");
991 return res;
992}
993
994ChangeResult IntervalDataFlowAnalysis::applyInterval(
995 Operation *originalOp, Lattice *after, Value val, Interval newInterval
996) {
997 auto latValRes = after->getValue(val);
998 if (failed(latValRes)) {
999 // visitOperation didn't add val to the lattice, so there's nothing to do
1000 return ChangeResult::NoChange;
1001 }
1002 ExpressionValue newLatticeVal = latValRes->getScalarValue().withInterval(newInterval);
1003 ChangeResult res = after->setValue(val, newLatticeVal);
1004 // To allow the dataflow analysis to do its fixed-point iteration, we need to
1005 // add the new expression to val's lattice as well.
1006 Lattice *valLattice = nullptr;
1007 if (auto valOp = val.getDefiningOp()) {
1008 // Getting the lattice at valOp gives us the "after" lattice, but we want to
1009 // update the "before" lattice so that the inputs to visitOperation will be
1010 // changed.
1011 if (auto prev = valOp->getPrevNode()) {
1012 valLattice = getOrCreate<Lattice>(prev);
1013 } else {
1014 valLattice = getOrCreate<Lattice>(valOp->getBlock());
1015 }
1016 } else if (auto blockArg = mlir::dyn_cast<BlockArgument>(val)) {
1017 valLattice = getOrCreate<Lattice>(blockArg.getOwner());
1018 } else {
1019 valLattice = getOrCreate<Lattice>(val);
1020 }
1021 ensure(valLattice, "val should have a lattice");
1022 if (valLattice != after) {
1023 propagateIfChanged(valLattice, valLattice->setValue(val, newLatticeVal));
1024 }
1025
1026 // Now we descend into val's operands, if it has any.
1027 Operation *definingOp = val.getDefiningOp();
1028 if (!definingOp) {
1029 return res;
1030 }
1031
1032 const Field &f = field.get();
1033
1034 // This is a rules-based operation. If we have a rule for a given operation,
1035 // then we can make some kind of update, otherwise we leave the intervals
1036 // as is.
1037 // - First we'll define all the rules so the type switch can be less messy
1038
1039 // cmp.<pred> restricts each side of the comparison if the result is known.
1040 auto cmpCase = [&](CmpOp cmpOp) {
1041 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
1042 // either "true" (1) or "false" (0)
1043 Interval maxInterval = Interval::Boolean(f);
1044 ensure(
1045 newInterval.intersect(maxInterval).isNotEmpty(),
1046 "new interval for CmpOp outside of allowed boolean range or is empty"
1047 );
1048 if (!newInterval.isDegenerate()) {
1049 // The comparison result is unknown, so we can't update the operand ranges
1050 return ChangeResult::NoChange;
1051 }
1052
1053 bool cmpTrue = newInterval.rhs() == f.one();
1054
1055 Value lhs = cmpOp->getOperand(0), rhs = cmpOp->getOperand(1);
1056 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
1057 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
1058 return ChangeResult::NoChange;
1059 }
1060 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
1061 rhsExpr = rhsLatValRes->getScalarValue();
1062
1063 Interval newLhsInterval, newRhsInterval;
1064 const Interval &lhsInterval = lhsExpr.getInterval();
1065 const Interval &rhsInterval = rhsExpr.getInterval();
1066
1067 FeltCmpPredicate pred = cmpOp.getPredicate();
1068 // predicate cases
1069 auto eqCase = [&]() {
1070 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1071 (pred == FeltCmpPredicate::NE && !cmpTrue);
1072 };
1073 auto neCase = [&]() {
1074 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1075 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1076 };
1077 auto ltCase = [&]() {
1078 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1079 (pred == FeltCmpPredicate::GE && !cmpTrue);
1080 };
1081 auto leCase = [&]() {
1082 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1083 (pred == FeltCmpPredicate::GT && !cmpTrue);
1084 };
1085 auto gtCase = [&]() {
1086 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1087 (pred == FeltCmpPredicate::LE && !cmpTrue);
1088 };
1089 auto geCase = [&]() {
1090 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1091 (pred == FeltCmpPredicate::LT && !cmpTrue);
1092 };
1093
1094 // new intervals based on case
1095 if (eqCase()) {
1096 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1097 } else if (neCase()) {
1098 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1099 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
1100 // an empty value range.
1101 newLhsInterval = newRhsInterval = Interval::Empty(f);
1102 } else {
1103 // Leave unchanged
1104 newLhsInterval = lhsInterval;
1105 newRhsInterval = rhsInterval;
1106 }
1107 } else if (ltCase()) {
1108 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1109 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1110 } else if (leCase()) {
1111 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1112 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1113 } else if (gtCase()) {
1114 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1115 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1116 } else if (geCase()) {
1117 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1118 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1119 } else {
1120 cmpOp->emitWarning("unhandled cmp predicate");
1121 return ChangeResult::NoChange;
1122 }
1123
1124 // Now we recurse to each operand
1125 return applyInterval(originalOp, after, lhs, newLhsInterval) |
1126 applyInterval(originalOp, after, rhs, newRhsInterval);
1127 };
1128
1129 // If the result of a multiplication is non-zero, then both operands must be
1130 // non-zero.
1131 auto mulCase = [&](MulFeltOp mulOp) {
1132 auto zeroInt = Interval::Degenerate(f, f.zero());
1133 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1134 // The multiplication may be zero, so we can't reduce the operands to be non-zero
1135 return ChangeResult::NoChange;
1136 }
1137
1138 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
1139 auto lhsLatValRes = after->getValue(lhs), rhsLatValRes = after->getValue(rhs);
1140 if (failed(lhsLatValRes) || failed(rhsLatValRes)) {
1141 return ChangeResult::NoChange;
1142 }
1143 ExpressionValue lhsExpr = lhsLatValRes->getScalarValue(),
1144 rhsExpr = rhsLatValRes->getScalarValue();
1145 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1146 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1147 return applyInterval(originalOp, after, lhs, newLhsInterval) |
1148 applyInterval(originalOp, after, rhs, newRhsInterval);
1149 };
1150
1151 // We have a special case for the Signal struct: if this value is created
1152 // from reading a Signal struct's reg field, we also apply the interval to
1153 // the struct itself.
1154 auto readfCase = [&](FieldReadOp readfOp) {
1155 Value comp = readfOp.getComponent();
1156 if (isSignalType(comp.getType())) {
1157 return applyInterval(originalOp, after, comp, newInterval);
1158 }
1159 return ChangeResult::NoChange;
1160 };
1161
1162 // - Apply the rules given the op.
1163 // NOTE: disabling clang-format for this because it makes the last case statement
1164 // look ugly.
1165 // clang-format off
1166 res |= TypeSwitch<Operation *, ChangeResult>(definingOp)
1167 .Case<CmpOp>([&](CmpOp op) { return cmpCase(op); })
1168 .Case<MulFeltOp>([&](MulFeltOp op) { return mulCase(op); })
1169 .Case<FieldReadOp>([&](FieldReadOp op){ return readfCase(op); })
1170 .Default([&](Operation *_) { return ChangeResult::NoChange; });
1171 // clang-format on
1172
1173 return res;
1174}
1175
1176/* StructIntervals */
1177
1179 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx
1180) {
1181 // Get the lattice at the end of the constrain function.
1182 function::ReturnOp constrainEnd;
1183 structDef.getConstrainFuncOp().walk([&constrainEnd](function::ReturnOp r) mutable {
1184 constrainEnd = r;
1185 });
1186
1187 const IntervalAnalysisLattice *constrainLattice =
1188 solver.lookupState<IntervalAnalysisLattice>(constrainEnd);
1189
1190 constrainSolverConstraints = constrainLattice->getConstraints();
1191
1192 for (const auto &ref : ConstrainRef::getAllConstrainRefs(structDef)) {
1193 // We only want to compute intervals for field elements and not composite types,
1194 // with the exception of the Signal struct.
1195 if (!ref.isScalar() && !ref.isSignal()) {
1196 continue;
1197 }
1198 // We also don't want to show the interval for a Signal and its internal reg.
1199 if (auto parentOr = ref.getParentPrefix(); succeeded(parentOr) && parentOr->isSignal()) {
1200 continue;
1201 }
1202 auto symbol = ctx.getSymbol(ref);
1203 auto constrainInterval = constrainLattice->findInterval(symbol);
1204 if (succeeded(constrainInterval)) {
1205 constrainFieldRanges[ref] = *constrainInterval;
1206 } else {
1207 constrainFieldRanges[ref] = Interval::Entire(ctx.field);
1208 }
1209 }
1210
1211 return success();
1212}
1213
1214void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints) const {
1215 os << "StructIntervals { ";
1216 if (constrainFieldRanges.empty()) {
1217 os << "}\n";
1218 return;
1219 }
1220
1221 for (auto &[ref, interval] : constrainFieldRanges) {
1222 os << "\n " << ref << " in " << interval;
1223 }
1224
1225 if (withConstraints) {
1226 os << "\n\n Solver Constraints { ";
1227 if (constrainSolverConstraints.empty()) {
1228 os << "}\n";
1229 } else {
1230 for (const auto &e : constrainSolverConstraints) {
1231 os << "\n ";
1232 e.getExpr()->print(os);
1233 }
1234 os << "\n }";
1235 }
1236 }
1237
1238 os << "\n}\n";
1239}
1240
1241} // namespace llzk
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
MlirStringRef name
A value at a given point of the ConstrainRefLattice.
const ConstrainRef & getSingleValue() const
A lattice for use in dense analysis.
Defines a reference to a llzk object within a constrain function call.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
friend ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
const Interval & getInterval() const
friend ExpressionValue cmp(llvm::SMTSolverRef solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue div(llvm::SMTSolverRef solver, felt::DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
const Field & getField() const
friend ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
Information about the prime finite field used for the interval analysis.
llvm::APSInt one() const
Returns 1 at the bitwidth of the field.
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::APSInt zero() const
Returns 0 at the bitwidth of the field.
Field()=delete
llvm::APSInt half() const
Returns p / 2.
unsigned bitWidth() const
llvm::APSInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
static const Field & getField(const char *fieldName)
Get a Field from a given field name string.
llvm::APSInt prime() const
For the prime field p, returns p.
llvm::APSInt reduce(llvm::APSInt i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
Maps mlir::Values to LatticeValues.
mlir::ChangeResult setValue(mlir::Value v, ExpressionValue e)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
const ConstraintSet & getConstraints() const
void print(mlir::raw_ostream &os) const override
mlir::FailureOr< LatticeValue > getValue(mlir::Value v) const
mlir::ChangeResult join(const AbstractDenseLattice &other) override
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
void visitOperation(mlir::Operation *op, const Lattice &before, Lattice *after) override
Visit an operation with the dense lattice before its execution.
void visitCallControlFlowTransfer(mlir::CallOpInterface call, dataflow::CallControlFlowAction action, const Lattice &before, Lattice *after) override
The interval analysis is intraprocedural only for now, so this control flow transfer function passes ...
llvm::SMTExprRef getOrCreateSymbol(const ConstrainRef &r)
Either return the existing SMT expression that corresponds to the ConstrainRef, or create one.
Intervals over a finite field.
bool isEmpty() const
Interval intersect(const Interval &rhs) const
Intersect.
static std::string_view TypeName(Type t)
llvm::APSInt rhs() const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
bool isDegenerate() const
void print(mlir::raw_ostream &os) const
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
static Interval TypeC(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeF(const Field &f, llvm::APSInt a, llvm::APSInt b)
static Interval TypeB(const Field &f, llvm::APSInt a, llvm::APSInt b)
bool isTypeF() const
friend Interval operator*(const Interval &lhs, const Interval &rhs)
llvm::APSInt lhs() const
static Interval Empty(const Field &f)
static bool areOneOf(const Interval &a, const Interval &b)
static Interval Degenerate(const Field &f, llvm::APSInt val)
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
friend mlir::FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Returns failure if a division-by-zero is encountered.
static Interval TypeA(const Field &f, llvm::APSInt a, llvm::APSInt b)
friend Interval operator+(const Interval &lhs, const Interval &rhs)
bool isEntire() const
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
friend Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, IntervalAnalysisContext &ctx)
void print(mlir::raw_ostream &os, bool withConstraints=false) const
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
UnreducedInterval operator-() const
friend UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
bool isEmpty() const
Returns true iff width() is zero.
UnreducedInterval(llvm::APSInt x, llvm::APSInt y)
llvm::APSInt width() const
Compute the width of this interval within a given field f.
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
friend std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
bool overlaps(const UnreducedInterval &rhs) const
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
friend UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:777
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs)
Join a lattice with another and propagate an update if it changed.
IntervalAnalysisLattice * getLattice(mlir::ProgramPoint point) override
mlir::dataflow::AbstractDenseLattice AbstractDenseLattice
mlir::dataflow::CallControlFlowAction CallControlFlowAction
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingSub(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely subtract lhs and rhs, expanding the width of the result as necessary.
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingAdd(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely add lhs and rhs, expanding the width of the result as necessary.
llvm::APSInt safeMax(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:90
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
llvm::APSInt expandingMul(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely multiply lhs and rhs, expanding the width of the result as necessary.
bool safeGt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:74
bool safeEq(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:66
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
bool safeLt(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:58
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
bool isSignalType(Type type)
bool safeLe(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:62
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
bool safeGe(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:78
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
llvm::APSInt safeMin(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Definition APIntHelper.h:82
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
std::reference_wrapper< const Field > field
llvm::SMTExprRef getSymbol(const ConstrainRef &r)