LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Intervals.cpp
Go to the documentation of this file.
1//===-- Intervals.cpp ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13
14using namespace mlir;
15
16namespace llzk {
17
18/* UnreducedInterval */
19
21 if (a > b) {
22 return Interval::Empty(field);
23 }
24 if (width() >= field.prime()) {
25 return Interval::Entire(field);
26 }
27 auto lhs = field.reduce(a), rhs = field.reduce(b);
28 if (rhs == lhs) {
29 return Interval::Degenerate(field, lhs);
30 }
31
32 const auto &half = field.half();
33 if (lhs <= rhs) {
34 if (lhs < half && rhs < half) {
35 return Interval::TypeA(field, lhs, rhs);
36 } else if (lhs < half) {
37 return Interval::TypeC(field, lhs, rhs);
38 } else {
39 return Interval::TypeB(field, lhs, rhs);
40 }
41 } else {
42 if (lhs >= half && rhs < half) {
43 return Interval::TypeF(field, lhs, rhs);
44 } else {
45 return Interval::Entire(field);
46 }
47 }
48}
49
51 const auto &lhs = *this;
52 return UnreducedInterval(std::max(lhs.a, rhs.a), std::min(lhs.b, rhs.b));
53}
54
56 const auto &lhs = *this;
57 return UnreducedInterval(std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
58}
59
61 if (isEmpty() || rhs.isEmpty()) {
62 return *this;
63 }
64 DynamicAPInt bound = rhs.b - 1;
65 return UnreducedInterval(a, std::min(b, bound));
66}
67
69 if (isEmpty() || rhs.isEmpty()) {
70 return *this;
71 }
72 return UnreducedInterval(a, std::min(b, rhs.b));
73}
74
76 if (isEmpty() || rhs.isEmpty()) {
77 return *this;
78 }
79 DynamicAPInt bound = rhs.a + 1;
80 return UnreducedInterval(std::max(a, bound), b);
81}
82
84 if (isEmpty() || rhs.isEmpty()) {
85 return *this;
86 }
87 return UnreducedInterval(std::max(a, rhs.a), b);
88}
89
91 if (isEmpty()) {
92 return *this;
93 }
94 return UnreducedInterval(-b, -a);
95}
96
98 DynamicAPInt low = lhs.a + rhs.a, high = lhs.b + rhs.b;
99 return UnreducedInterval(low, high);
100}
101
103 return lhs + (-rhs);
104}
105
107 DynamicAPInt v1 = lhs.a * rhs.a;
108 DynamicAPInt v2 = lhs.a * rhs.b;
109 DynamicAPInt v3 = lhs.b * rhs.a;
110 DynamicAPInt v4 = lhs.b * rhs.b;
111
112 auto minVal = std::min({v1, v2, v3, v4});
113 auto maxVal = std::max({v1, v2, v3, v4});
114
115 return UnreducedInterval(minVal, maxVal);
116}
117
119 return isNotEmpty() && rhs.isNotEmpty() && (b >= rhs.a) && (a <= rhs.b);
120}
121
122std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
123 if ((lhs.a < rhs.a) || ((lhs.a == rhs.a) && (lhs.b < rhs.b))) {
124 return std::strong_ordering::less;
125 }
126 if ((lhs.a > rhs.a) || ((lhs.a == rhs.a) && (lhs.b > rhs.b))) {
127 return std::strong_ordering::greater;
128 }
129 return std::strong_ordering::equal;
130}
131
132DynamicAPInt UnreducedInterval::width() const {
133 DynamicAPInt w;
134 if (a > b) {
135 // This would be reduced to an empty Interval, so the width is just zero.
136 w = 0;
137 } else {
138 // Since the range is inclusive, we add one to the difference to get the true width.
139 w = (b - a) + 1;
140 }
141 ensure(w >= 0, "cannot have negative width");
142 return w;
143}
144
145/* Interval */
146
147const Field &checkFields(const Interval &lhs, const Interval &rhs) {
148 ensure(
149 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
150 );
151 return lhs.getField();
152}
153
155 if (isEmpty()) {
156 // Since ranges are inclusive, empty is encoded as `[a, b]` where `a` > `b`.
157 // This matches the definition provided by UnreducedInterval::width().
158 return UnreducedInterval(field.get().one(), field.get().zero());
159 }
160 if (isEntire()) {
161 return UnreducedInterval(field.get().zero(), field.get().maxVal());
162 }
163 return UnreducedInterval(a, b);
164}
165
167 if (is<Type::TypeF>()) {
168 return UnreducedInterval(a - field.get().prime(), b);
169 }
170 return toUnreduced();
171}
172
174 ensure(is<Type::TypeA, Type::TypeB, Type::TypeC>(), "unsupported range type");
175 return UnreducedInterval(a - field.get().prime(), b - field.get().prime());
176}
177
179 const auto &lhs = *this;
180 const Field &f = checkFields(lhs, rhs);
181
182 // Trivial cases
183 if (lhs.isEntire() || rhs.isEntire()) {
184 return Interval::Entire(f);
185 }
186 if (lhs.isEmpty()) {
187 return rhs;
188 }
189 if (rhs.isEmpty()) {
190 return lhs;
191 }
192 if (lhs.isDegenerate() || rhs.isDegenerate()) {
193 return lhs.toUnreduced().doUnion(rhs.toUnreduced()).reduce(f);
194 }
195
196 // More complex cases
197 if (areOneOf<
200 return Interval(rhs.ty, f, std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
201 }
203 auto lhsUnred = lhs.firstUnreduced();
204 auto opt1 = rhs.firstUnreduced().doUnion(lhsUnred);
205 auto opt2 = rhs.secondUnreduced().doUnion(lhsUnred);
206 if (opt1.width() <= opt2.width()) {
207 return opt1.reduce(f);
208 }
209 return opt2.reduce(f);
210 }
212 return lhs.firstUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
213 }
215 return lhs.secondUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
216 }
218 return Interval::Entire(f);
219 }
220 if (areOneOf<
223 lhs, rhs
224 )) {
225 return rhs.join(lhs);
226 }
227 llvm::report_fatal_error("unhandled join case");
228 return Interval::Entire(f);
229}
230
232 const auto &lhs = *this;
233 const Field &f = checkFields(lhs, rhs);
234 // Trivial cases
235 if (lhs.isEmpty() || rhs.isEmpty()) {
236 return Interval::Empty(f);
237 }
238 if (lhs.isEntire()) {
239 return rhs;
240 }
241 if (rhs.isEntire()) {
242 return lhs;
243 }
244 if (lhs.isDegenerate() || rhs.isDegenerate()) {
245 return lhs.toUnreduced().intersect(rhs.toUnreduced()).reduce(f);
246 }
247
248 // More complex cases
249 if (areOneOf<
252 auto maxA = std::max(lhs.a, rhs.a);
253 auto minB = std::min(lhs.b, rhs.b);
254 if (maxA <= minB) {
255 return Interval(lhs.ty, f, maxA, minB);
256 } else {
257 return Interval::Empty(f);
258 }
259 }
261 return Interval::Empty(f);
262 }
264 return lhs.firstUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
265 }
267 return lhs.secondUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
268 }
270 auto rhsUnred = rhs.firstUnreduced();
271 auto opt1 = lhs.firstUnreduced().intersect(rhsUnred).reduce(f);
272 auto opt2 = lhs.secondUnreduced().intersect(rhsUnred).reduce(f);
273 ensure(!opt1.isEntire() && !opt2.isEntire(), "impossible intersection");
274 if (opt1.isEmpty()) {
275 return opt2;
276 }
277 if (opt2.isEmpty()) {
278 return opt1;
279 }
280 return opt1.join(opt2);
281 }
282 if (areOneOf<
285 lhs, rhs
286 )) {
287 return rhs.intersect(lhs);
288 }
289 return Interval::Empty(f);
290}
291
293 const Field &f = checkFields(*this, other);
294 // intersect checks that we're in the same field
296 if (intersection.isEmpty()) {
297 // There's nothing to remove, so just return this
298 return *this;
299 }
300
301 // Trivial cases with a non-empty intersection
302 if (isDegenerate() || other.isEntire()) {
303 return Interval::Empty(f);
304 }
305 if (isEntire()) {
306 // Since we don't support punching arbitrary holes in ranges, we only reduce
307 // entire ranges if other is [0, b] or [a, prime - 1]
308 if (other.a == f.zero()) {
309 return UnreducedInterval(other.b + f.one(), f.maxVal()).reduce(f);
310 }
311 if (other.b == f.maxVal()) {
312 return UnreducedInterval(f.zero(), other.a - f.one()).reduce(f);
313 }
314
315 return *this;
316 }
317
318 // Non-trivial cases
319 // - Internal+internal or external+external cases
322 areOneOf<{Type::TypeF, Type::TypeF}>(*this, intersection)) {
323 // The intersection needs to be at the end of the interval, otherwise we would
324 // split the interval in two, and we aren't set up to support multiple intervals
325 // per value.
326 if (a != intersection.a && b != intersection.b) {
327 return *this;
328 }
329 // Otherwise, remove the intersection and reduce
330 if (a == intersection.a) {
331 return UnreducedInterval(intersection.b + f.one(), b).reduce(f);
332 }
333 // else b == intersection.b
334 return UnreducedInterval(a, intersection.a - f.one()).reduce(f);
335 }
336 // - Mixed internal/external cases. We flip the comparison
337 if (isTypeF()) {
338 if (a != intersection.b && b != intersection.a) {
339 return *this;
340 }
341 // Otherwise, remove the intersection and reduce
342 if (a == intersection.b) {
343 return UnreducedInterval(intersection.a + f.one(), b).reduce(f);
344 }
345 // else b == intersection.a
346 return UnreducedInterval(a, intersection.b - f.one()).reduce(f);
347 }
348
349 // In cases we don't know how to handle, we over-approximate and return
350 // the original interval.
351 return *this;
352}
353
354Interval Interval::operator-() const { return (-firstUnreduced()).reduce(field.get()); }
355
357 return Interval::Degenerate(field.get(), field.get().one()) - *this;
358}
359
361 const Field &f = checkFields(lhs, rhs);
362 if (lhs.isEmpty()) {
363 return rhs;
364 }
365 if (rhs.isEmpty()) {
366 return lhs;
367 }
368 return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(f);
369}
370
371Interval operator-(const Interval &lhs, const Interval &rhs) { return lhs + (-rhs); }
372
374 const Field &f = checkFields(lhs, rhs);
375 auto zeroInterval = Interval::Degenerate(f, f.zero());
376 if (lhs == zeroInterval || rhs == zeroInterval) {
377 return zeroInterval;
378 }
379 if (lhs.isEmpty() || rhs.isEmpty()) {
380 return Interval::Empty(f);
381 }
382 if (lhs.isEntire() || rhs.isEntire()) {
383 return Interval::Entire(f);
384 }
385
387 return (lhs.secondUnreduced() * rhs.secondUnreduced()).reduce(f);
388 }
389 return (lhs.firstUnreduced() * rhs.firstUnreduced()).reduce(f);
390}
391
392FailureOr<Interval> operator/(const Interval &lhs, const Interval &rhs) {
393 const Field &f = checkFields(lhs, rhs);
394 if (rhs.width() > f.one()) {
395 return Interval::Entire(f);
396 }
397 if (rhs.a == 0) {
398 return failure();
399 }
400 return success(UnreducedInterval(lhs.a / rhs.a, lhs.b / rhs.a).reduce(f));
401}
402
404 const Field &f = checkFields(lhs, rhs);
  // Over-approximates lhs % rhs by the range [0, rhs.b], reduced into the field;
  // lhs does not participate in the bound.
  // NOTE(review): for a TypeF (wrapping) rhs, `b` is below `a` and is not the
  // largest value in the range (the range runs up to p - 1), and for Empty/Entire
  // rhs the stored `b` may not be meaningful (toUnreduced() special-cases both).
  // Confirm this bound is sound for those interval types.
405 return UnreducedInterval(f.zero(), rhs.b).reduce(f);
406}
407
409 const Field &f = checkFields(lhs, rhs);
410 if (lhs.isEmpty() || rhs.isEmpty()) {
411 return Interval::Empty(f);
412 }
413 if (lhs.isDegenerate() && rhs.isDegenerate()) {
414 return Interval::Degenerate(f, lhs.a & rhs.a);
415 } else if (lhs.isDegenerate()) {
416 return UnreducedInterval(f.zero(), lhs.a).reduce(f);
417 } else if (rhs.isDegenerate()) {
418 return UnreducedInterval(f.zero(), rhs.a).reduce(f);
419 }
420 return Interval::Entire(f);
421}
422
424 const Field &f = checkFields(lhs, rhs);
425 if (lhs.isEmpty() || rhs.isEmpty()) {
426 return Interval::Empty(f);
427 }
428 if (lhs.isDegenerate() && rhs.isDegenerate()) {
429 if (rhs.a > f.bitWidth()) {
430 return Interval::Entire(f);
431 }
432
433 DynamicAPInt v = lhs.a << rhs.a;
434 return UnreducedInterval(v, v).reduce(f);
435 }
436 return Interval::Entire(f);
437}
438
440 const Field &f = checkFields(lhs, rhs);
441 if (lhs.isEmpty() || rhs.isEmpty()) {
442 return Interval::Empty(f);
443 }
444 if (lhs.isDegenerate() && rhs.isDegenerate()) {
445 if (rhs.a > f.bitWidth()) {
446 return Interval::Degenerate(f, f.zero());
447 }
448
449 return Interval::Degenerate(f, lhs.a >> rhs.a);
450 }
451 return Interval::Entire(f);
452}
453
454DynamicAPInt Interval::width() const {
455 switch (ty) {
456 case Type::Empty:
457 return field.get().zero();
458 case Type::Degenerate:
459 return field.get().one();
460 case Type::Entire:
461 return field.get().prime();
462 default:
463 return field.get().reduce(toUnreduced().width());
464 }
465}
466
468 ensure(
469 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
470 );
471 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
472 const auto &field = rhs.getField();
473
474 if (lhs.isBoolFalse() || rhs.isBoolFalse()) {
475 return Interval::False(field);
476 }
477 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
478 return Interval::True(field);
479 }
480
481 return Interval::Boolean(field);
482}
483
485 ensure(
486 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
487 );
488 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
489 const auto &field = rhs.getField();
490
491 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
492 return Interval::False(field);
493 }
494 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
495 return Interval::True(field);
496 }
497
498 return Interval::Boolean(field);
499}
500
502 ensure(
503 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
504 );
505 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
506 const auto &field = rhs.getField();
507
508 // Xor-ing anything with [0, 1] could still result in either case, so just return
509 // the full boolean range.
510 if (lhs.isBoolEither() || rhs.isBoolEither()) {
511 return Interval::Boolean(lhs.getField());
512 }
513
514 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
515 return Interval::False(field);
516 }
517 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
518 return Interval::True(field);
519 }
520 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
521 return Interval::False(field);
522 }
523
524 return Interval::Boolean(field);
525}
526
528 ensure(iv.isBoolean(), "operation only supported for boolean-type intervals");
529 const auto &field = iv.getField();
530
531 if (iv.isBoolTrue()) {
532 return Interval::False(field);
533 }
534 if (iv.isBoolFalse()) {
535 return Interval::True(field);
536 }
537
538 return iv;
539}
540
541void Interval::print(mlir::raw_ostream &os) const {
542 os << TypeName(ty);
543 if (is<Type::Degenerate>()) {
544 os << '(' << a << ')';
545 } else if (!is<Type::Entire, Type::Empty>()) {
546 os << ":[ " << a << ", " << b << " ]";
547 }
548}
549
550} // namespace llzk
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:25
llvm::DynamicAPInt half() const
Returns p / 2.
Definition Field.h:40
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
Definition Field.h:46
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:37
llvm::DynamicAPInt reduce(const llvm::DynamicAPInt &i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:49
unsigned bitWidth() const
Definition Field.h:61
llvm::DynamicAPInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
Definition Field.h:52
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:304
static Interval True(const Field &f)
Definition Intervals.h:219
llvm::DynamicAPInt rhs() const
Definition Intervals.h:329
Interval intersect(const Interval &rhs) const
Intersect.
bool isBoolean() const
Definition Intervals.h:316
static std::string_view TypeName(Type t)
Definition Intervals.h:207
void print(llvm::raw_ostream &os) const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
bool isBoolFalse() const
Definition Intervals.h:313
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an UnreducedInterval (via toUnreduced()).
bool isDegenerate() const
Definition Intervals.h:306
const Field & getField() const
Definition Intervals.h:324
bool isBoolTrue() const
Definition Intervals.h:314
bool is() const
Definition Intervals.h:318
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
bool isTypeF() const
Definition Intervals.h:311
static Interval False(const Field &f)
Definition Intervals.h:217
Interval operator~() const
llvm::DynamicAPInt lhs() const
Definition Intervals.h:328
static bool areOneOf(const Interval &a, const Interval &b)
Definition Intervals.h:257
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
Definition Intervals.h:243
llvm::DynamicAPInt width() const
bool isEntire() const
Definition Intervals.h:307
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
Definition Intervals.h:26
UnreducedInterval operator-() const
Definition Intervals.cpp:90
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
Definition Intervals.cpp:50
UnreducedInterval(const llvm::DynamicAPInt &x, const llvm::DynamicAPInt &y)
Definition Intervals.h:28
bool isEmpty() const
Returns true iff width() is zero.
Definition Intervals.h:114
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:60
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:83
bool isNotEmpty() const
Definition Intervals.h:116
bool overlaps(const UnreducedInterval &rhs) const
llvm::DynamicAPInt width() const
Compute the width (number of integers) of this inclusive interval, i.e. b - a + 1, or zero when a > b.
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
Definition Intervals.cpp:55
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:75
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:20
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:68
void ensure(bool condition, const llvm::Twine &errMsg)
FailureOr< Interval > operator/(const Interval &lhs, const Interval &rhs)
Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval operator-(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
Interval operator>>(const Interval &lhs, const Interval &rhs)
Interval operator&(const Interval &lhs, const Interval &rhs)
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
const Field & checkFields(const Interval &lhs, const Interval &rhs)
UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
Definition Intervals.cpp:97
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)