LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Intervals.cpp
Go to the documentation of this file.
1//===-- Intervals.cpp ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13
14using namespace mlir;
15
16namespace llzk {
17
18/* UnreducedInterval */
19
21 if (a > b) {
22 return Interval::Empty(field);
23 }
24 if (width() >= field.prime()) {
25 return Interval::Entire(field);
26 }
27 auto lhs = field.reduce(a), rhs = field.reduce(b);
28 if (rhs == lhs) {
29 return Interval::Degenerate(field, lhs);
30 }
31
32 const auto &half = field.half();
33 if (lhs <= rhs) {
34 if (lhs < half && rhs < half) {
35 return Interval::TypeA(field, lhs, rhs);
36 } else if (lhs < half) {
37 return Interval::TypeC(field, lhs, rhs);
38 } else {
39 return Interval::TypeB(field, lhs, rhs);
40 }
41 } else {
42 if (lhs >= half && rhs < half) {
43 return Interval::TypeF(field, lhs, rhs);
44 } else {
45 return Interval::Entire(field);
46 }
47 }
48}
49
51 const auto &lhs = *this;
52 return UnreducedInterval(std::max(lhs.a, rhs.a), std::min(lhs.b, rhs.b));
53}
54
56 const auto &lhs = *this;
57 return UnreducedInterval(std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
58}
59
61 if (isEmpty() || rhs.isEmpty()) {
62 return *this;
63 }
64 DynamicAPInt bound = rhs.b - 1;
65 return UnreducedInterval(a, std::min(b, bound));
66}
67
69 if (isEmpty() || rhs.isEmpty()) {
70 return *this;
71 }
72 return UnreducedInterval(a, std::min(b, rhs.b));
73}
74
76 if (isEmpty() || rhs.isEmpty()) {
77 return *this;
78 }
79 DynamicAPInt bound = rhs.a + 1;
80 return UnreducedInterval(std::max(a, bound), b);
81}
82
84 if (isEmpty() || rhs.isEmpty()) {
85 return *this;
86 }
87 return UnreducedInterval(std::max(a, rhs.a), b);
88}
89
91 if (isEmpty()) {
92 return *this;
93 }
94 return UnreducedInterval(-b, -a);
95}
96
98 DynamicAPInt low = lhs.a + rhs.a, high = lhs.b + rhs.b;
99 return UnreducedInterval(low, high);
100}
101
103 return lhs + (-rhs);
104}
105
107 DynamicAPInt v1 = lhs.a * rhs.a;
108 DynamicAPInt v2 = lhs.a * rhs.b;
109 DynamicAPInt v3 = lhs.b * rhs.a;
110 DynamicAPInt v4 = lhs.b * rhs.b;
111
112 auto minVal = std::min({v1, v2, v3, v4});
113 auto maxVal = std::max({v1, v2, v3, v4});
114
115 return UnreducedInterval(minVal, maxVal);
116}
117
119 return isNotEmpty() && rhs.isNotEmpty() && (b >= rhs.a) && (a <= rhs.b);
120}
121
122std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
123 if ((lhs.a < rhs.a) || ((lhs.a == rhs.a) && (lhs.b < rhs.b))) {
124 return std::strong_ordering::less;
125 }
126 if ((lhs.a > rhs.a) || ((lhs.a == rhs.a) && (lhs.b > rhs.b))) {
127 return std::strong_ordering::greater;
128 }
129 return std::strong_ordering::equal;
130}
131
132DynamicAPInt UnreducedInterval::width() const {
133 DynamicAPInt w;
134 if (a > b) {
135 // This would be reduced to an empty Interval, so the width is just zero.
136 w = 0;
137 } else {
138 // Since the range is inclusive, we add one to the difference to get the true width.
139 w = (b - a) + 1;
140 }
141 ensure(w >= 0, "cannot have negative width");
142 return w;
143}
144
145/* Interval */
146
147const Field &checkFields(const Interval &lhs, const Interval &rhs) {
148 ensure(
149 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
150 );
151 return lhs.getField();
152}
153
155 if (isEmpty()) {
156 // Since ranges are inclusive, empty is encoded as `[a, b]` where `a` > `b`.
157 // This matches the definition provided by UnreducedInterval::width().
158 return UnreducedInterval(field.get().one(), field.get().zero());
159 }
160 if (isEntire()) {
161 return UnreducedInterval(field.get().zero(), field.get().maxVal());
162 }
163 return UnreducedInterval(a, b);
164}
165
167 if (is<Type::TypeF>()) {
168 return UnreducedInterval(a - field.get().prime(), b);
169 }
170 return toUnreduced();
171}
172
174 ensure(is<Type::TypeA, Type::TypeB, Type::TypeC>(), "unsupported range type");
175 return UnreducedInterval(a - field.get().prime(), b - field.get().prime());
176}
177
179 const auto &lhs = *this;
180 const Field &f = checkFields(lhs, rhs);
181
182 // Trivial cases
183 if (lhs.isEntire() || rhs.isEntire()) {
184 return Interval::Entire(f);
185 }
186 if (lhs.isEmpty()) {
187 return rhs;
188 }
189 if (rhs.isEmpty()) {
190 return lhs;
191 }
192 if (lhs.isDegenerate() || rhs.isDegenerate()) {
193 return lhs.toUnreduced().doUnion(rhs.toUnreduced()).reduce(f);
194 }
195
196 // More complex cases
197 if (areOneOf<
200 auto newLhs = std::min(lhs.a, rhs.a);
201 auto newRhs = std::max(lhs.b, rhs.b);
202 if (newLhs == newRhs) {
203 return Interval::Degenerate(f, newLhs);
204 }
205 return Interval(rhs.ty, f, newLhs, newRhs);
206 }
208 auto lhsUnred = lhs.firstUnreduced();
209 auto opt1 = rhs.firstUnreduced().doUnion(lhsUnred);
210 auto opt2 = rhs.secondUnreduced().doUnion(lhsUnred);
211 if (opt1.width() <= opt2.width()) {
212 return opt1.reduce(f);
213 }
214 return opt2.reduce(f);
215 }
217 return lhs.firstUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
218 }
220 return lhs.secondUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
221 }
223 return Interval::Entire(f);
224 }
225 if (areOneOf<
228 lhs, rhs
229 )) {
230 return rhs.join(lhs);
231 }
232 llvm::report_fatal_error("unhandled join case");
233 return Interval::Entire(f);
234}
235
237 const auto &lhs = *this;
238 const Field &f = checkFields(lhs, rhs);
239 // Trivial cases
240 if (lhs == rhs) {
241 return lhs;
242 }
243 if (lhs.isEmpty() || rhs.isEmpty()) {
244 return Interval::Empty(f);
245 }
246 if (lhs.isEntire()) {
247 return rhs;
248 }
249 if (rhs.isEntire()) {
250 return lhs;
251 }
252 if (lhs.isDegenerate() && rhs.isDegenerate()) {
253 // These must not be equal
254 return Interval::Empty(f);
255 }
256 if (lhs.isDegenerate()) {
257 return Interval::TypeA(f, lhs.a, lhs.a).intersect(rhs);
258 }
259 if (rhs.isDegenerate()) {
260 return Interval::TypeA(f, rhs.a, rhs.a).intersect(lhs);
261 }
262
263 // More complex cases
264 if (areOneOf<
267 auto maxA = std::max(lhs.a, rhs.a);
268 auto minB = std::min(lhs.b, rhs.b);
269 if (maxA < minB) {
270 return Interval(lhs.ty, f, maxA, minB);
271 } else if (maxA == minB) {
272 return Interval::Degenerate(f, maxA);
273 } else {
274 return Interval::Empty(f);
275 }
276 }
278 return Interval::Empty(f);
279 }
281 return lhs.firstUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
282 }
284 return lhs.secondUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
285 }
287 auto rhsUnred = rhs.firstUnreduced();
288 auto opt1 = lhs.firstUnreduced().intersect(rhsUnred).reduce(f);
289 auto opt2 = lhs.secondUnreduced().intersect(rhsUnred).reduce(f);
290 ensure(!opt1.isEntire() && !opt2.isEntire(), "impossible intersection");
291 if (opt1.isEmpty()) {
292 return opt2;
293 }
294 if (opt2.isEmpty()) {
295 return opt1;
296 }
297 return opt1.join(opt2);
298 }
299 if (areOneOf<
302 lhs, rhs
303 )) {
304 return rhs.intersect(lhs);
305 }
306 return Interval::Empty(f);
307}
308
310 const Field &f = checkFields(*this, other);
311 // intersect checks that we're in the same field
313 if (intersection.isEmpty()) {
314 // There's nothing to remove, so just return this
315 return *this;
316 }
317
318 // Trivial cases with a non-empty intersection
319 if (isDegenerate() || other.isEntire()) {
320 return Interval::Empty(f);
321 }
322 if (isEntire()) {
323 // Since we don't support punching arbitrary holes in ranges, we only reduce
324 // entire ranges if other is [0, b] or [a, prime - 1]
325 if (other.a == f.zero()) {
326 return UnreducedInterval(other.b + f.one(), f.maxVal()).reduce(f);
327 }
328 if (other.b == f.maxVal()) {
329 return UnreducedInterval(f.zero(), other.a - f.one()).reduce(f);
330 }
331
332 return *this;
333 }
334
335 // Non-trivial cases
336 // - Internal+internal or external+external cases
339 areOneOf<{Type::TypeF, Type::TypeF}>(*this, intersection)) {
340 // The intersection needs to be at the end of the interval, otherwise we would
341 // split the interval in two, and we aren't set up to support multiple intervals
342 // per value.
343 if (a != intersection.a && b != intersection.b) {
344 return *this;
345 }
346 // Otherwise, remove the intersection and reduce
347 if (a == intersection.a) {
348 return UnreducedInterval(intersection.b + f.one(), b).reduce(f);
349 }
350 // else b == intersection.b
351 return UnreducedInterval(a, intersection.a - f.one()).reduce(f);
352 }
353 // - Mixed internal/external cases. We flip the comparison
354 if (isTypeF()) {
355 if (a != intersection.b && b != intersection.a) {
356 return *this;
357 }
358 // Otherwise, remove the intersection and reduce
359 if (a == intersection.b) {
360 return UnreducedInterval(intersection.a + f.one(), b).reduce(f);
361 }
362 // else b == intersection.a
363 return UnreducedInterval(a, intersection.b - f.one()).reduce(f);
364 }
365
366 // In cases we don't know how to handle, we over-approximate and return
367 // the original interval.
368 return *this;
369}
370
371Interval Interval::operator-() const { return (-firstUnreduced()).reduce(field.get()); }
372
374 return Interval::Degenerate(field.get(), field.get().one()) - *this;
375}
376
378 const Field &f = checkFields(lhs, rhs);
379 if (lhs.isEmpty() || rhs.isEntire()) {
380 return rhs;
381 }
382 if (rhs.isEmpty() || lhs.isEntire()) {
383 return lhs;
384 }
385 return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(f);
386}
387
388Interval operator-(const Interval &lhs, const Interval &rhs) { return lhs + (-rhs); }
389
391 const Field &f = checkFields(lhs, rhs);
392 auto zeroInterval = Interval::Degenerate(f, f.zero());
393 if (lhs == zeroInterval || rhs == zeroInterval) {
394 return zeroInterval;
395 }
396 if (lhs.isEmpty() || rhs.isEmpty()) {
397 return Interval::Empty(f);
398 }
399 if (lhs.isEntire() || rhs.isEntire()) {
400 return Interval::Entire(f);
401 }
402
404 return (lhs.secondUnreduced() * rhs.secondUnreduced()).reduce(f);
405 }
406 return (lhs.firstUnreduced() * rhs.firstUnreduced()).reduce(f);
407}
408
409FailureOr<Interval> operator/(const Interval &lhs, const Interval &rhs) {
410 const Field &f = checkFields(lhs, rhs);
411 if (rhs.width() > f.one()) {
412 return Interval::Entire(f);
413 }
414 if (rhs.a == 0) {
415 return failure();
416 }
417 return success(UnreducedInterval(lhs.a / rhs.a, lhs.b / rhs.a).reduce(f));
418}
419
421 const Field &f = checkFields(lhs, rhs);
422 return UnreducedInterval(f.zero(), rhs.b).reduce(f);
423}
424
426 const Field &f = checkFields(lhs, rhs);
427 if (lhs.isEmpty() || rhs.isEmpty()) {
428 return Interval::Empty(f);
429 }
430 if (lhs.isDegenerate() && rhs.isDegenerate()) {
431 return Interval::Degenerate(f, lhs.a & rhs.a);
432 } else if (lhs.isDegenerate()) {
433 return UnreducedInterval(f.zero(), lhs.a).reduce(f);
434 } else if (rhs.isDegenerate()) {
435 return UnreducedInterval(f.zero(), rhs.a).reduce(f);
436 }
437 return Interval::Entire(f);
438}
439
441 const Field &f = checkFields(lhs, rhs);
442 if (lhs.isEmpty() || rhs.isEmpty()) {
443 return Interval::Empty(f);
444 }
445 if (lhs.isDegenerate() && rhs.isDegenerate()) {
446 if (rhs.a > f.bitWidth()) {
447 return Interval::Entire(f);
448 }
449
450 DynamicAPInt v = lhs.a << rhs.a;
451 return UnreducedInterval(v, v).reduce(f);
452 }
453 return Interval::Entire(f);
454}
455
457 const Field &f = checkFields(lhs, rhs);
458 if (lhs.isEmpty() || rhs.isEmpty()) {
459 return Interval::Empty(f);
460 }
461 if (lhs.isDegenerate() && rhs.isDegenerate()) {
462 if (rhs.a > f.bitWidth()) {
463 return Interval::Degenerate(f, f.zero());
464 }
465
466 return Interval::Degenerate(f, lhs.a >> rhs.a);
467 }
468 return Interval::Entire(f);
469}
470
471DynamicAPInt Interval::width() const {
472 switch (ty) {
473 case Type::Empty:
474 return field.get().zero();
475 case Type::Degenerate:
476 return field.get().one();
477 case Type::Entire:
478 return field.get().prime();
479 default:
480 return field.get().reduce(toUnreduced().width());
481 }
482}
483
485 ensure(
486 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
487 );
488 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
489 const auto &field = rhs.getField();
490
491 if (lhs.isBoolFalse() || rhs.isBoolFalse()) {
492 return Interval::False(field);
493 }
494 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
495 return Interval::True(field);
496 }
497
498 return Interval::Boolean(field);
499}
500
502 ensure(
503 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
504 );
505 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
506 const auto &field = rhs.getField();
507
508 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
509 return Interval::False(field);
510 }
511 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
512 return Interval::True(field);
513 }
514
515 return Interval::Boolean(field);
516}
517
519 ensure(
520 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
521 );
522 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
523 const auto &field = rhs.getField();
524
525 // Xor-ing anything with [0, 1] could still result in either case, so just return
526 // the full boolean range.
527 if (lhs.isBoolEither() || rhs.isBoolEither()) {
528 return Interval::Boolean(lhs.getField());
529 }
530
531 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
532 return Interval::False(field);
533 }
534 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
535 return Interval::True(field);
536 }
537 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
538 return Interval::False(field);
539 }
540
541 return Interval::Boolean(field);
542}
543
545 ensure(iv.isBoolean(), "operation only supported for boolean-type intervals");
546 const auto &field = iv.getField();
547
548 if (iv.isBoolTrue()) {
549 return Interval::False(field);
550 }
551 if (iv.isBoolFalse()) {
552 return Interval::True(field);
553 }
554
555 return iv;
556}
557
558void Interval::print(mlir::raw_ostream &os) const {
559 os << TypeName(ty);
560 if (is<Type::Degenerate>()) {
561 os << '(' << a << ')';
562 } else if (!is<Type::Entire, Type::Empty>()) {
563 os << ":[ " << a << ", " << b << " ]";
564 }
565}
566
567} // namespace llzk
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:27
llvm::DynamicAPInt half() const
Returns p / 2.
Definition Field.h:42
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
Definition Field.h:48
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:39
llvm::DynamicAPInt reduce(const llvm::DynamicAPInt &i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:51
unsigned bitWidth() const
Definition Field.h:68
llvm::DynamicAPInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
Definition Field.h:54
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:304
static Interval True(const Field &f)
Definition Intervals.h:219
llvm::DynamicAPInt rhs() const
Definition Intervals.h:329
Interval intersect(const Interval &rhs) const
Intersect.
bool isBoolean() const
Definition Intervals.h:316
static std::string_view TypeName(Type t)
Definition Intervals.h:207
void print(llvm::raw_ostream &os) const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
bool isBoolFalse() const
Definition Intervals.h:313
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals (i.e. [a - p, b]); otherwise, get the full interval as an UnreducedInterval.
bool isDegenerate() const
Definition Intervals.h:306
const Field & getField() const
Definition Intervals.h:324
bool isBoolTrue() const
Definition Intervals.h:314
bool is() const
Definition Intervals.h:318
UnreducedInterval secondUnreduced() const
Get the interval shifted down by the field prime, i.e. [a - p, b - p]. Only supported for TypeA, TypeB, and TypeC intervals.
bool isTypeF() const
Definition Intervals.h:311
static Interval False(const Field &f)
Definition Intervals.h:217
Interval operator~() const
llvm::DynamicAPInt lhs() const
Definition Intervals.h:328
static bool areOneOf(const Interval &a, const Interval &b)
Definition Intervals.h:257
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
Definition Intervals.h:243
llvm::DynamicAPInt width() const
bool isEntire() const
Definition Intervals.h:307
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given field.
Definition Intervals.h:26
UnreducedInterval operator-() const
Definition Intervals.cpp:90
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
Definition Intervals.cpp:50
UnreducedInterval(const llvm::DynamicAPInt &x, const llvm::DynamicAPInt &y)
Definition Intervals.h:28
bool isEmpty() const
Returns true iff width() is zero.
Definition Intervals.h:114
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:60
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:83
bool isNotEmpty() const
Definition Intervals.h:116
bool overlaps(const UnreducedInterval &rhs) const
llvm::DynamicAPInt width() const
Compute the width of this inclusive interval: (b - a) + 1, or 0 when the interval is empty (a > b).
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
Definition Intervals.cpp:55
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:75
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:20
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:68
void ensure(bool condition, const llvm::Twine &errMsg)
FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval operator-(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
Interval operator>>(const Interval &lhs, const Interval &rhs)
Interval operator&(const Interval &lhs, const Interval &rhs)
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
const Field & checkFields(const Interval &lhs, const Interval &rhs)
UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
Definition Intervals.cpp:97
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)