LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRef.h
Go to the documentation of this file.
1//===-- ConstrainRef.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
#include "llzk/Util/Hash.h"

#include <mlir/Analysis/DataFlowFramework.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Pass/AnalysisManager.h>

#include <llvm/ADT/EquivalenceClasses.h>

#include <optional>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
28
29namespace llzk {
30
// Pair of APInt bounds (low, high) describing a range of array indices.
// NOTE(review): whether `high` is inclusive or exclusive is not visible here — confirm.
35 using IndexRange = std::pair<mlir::APInt, mlir::APInt>;
36
37public:
38 explicit ConstrainRefIndex(component::FieldDefOp f) : index(f) {}
40 explicit ConstrainRefIndex(mlir::APInt i) : index(i) {}
41 explicit ConstrainRefIndex(int64_t i) : index(toAPInt(i)) {}
42 ConstrainRefIndex(mlir::APInt low, mlir::APInt high) : index(IndexRange {low, high}) {}
43 explicit ConstrainRefIndex(IndexRange r) : index(r) {}
44
45 bool isField() const {
46 return std::holds_alternative<SymbolLookupResult<component::FieldDefOp>>(index) ||
47 std::holds_alternative<component::FieldDefOp>(index);
48 }
50 ensure(isField(), "ConstrainRefIndex: field requested but not contained");
51 if (std::holds_alternative<component::FieldDefOp>(index)) {
52 return std::get<component::FieldDefOp>(index);
53 }
54 return std::get<SymbolLookupResult<component::FieldDefOp>>(index).get();
55 }
56
57 bool isIndex() const { return std::holds_alternative<mlir::APInt>(index); }
58 mlir::APInt getIndex() const {
59 ensure(isIndex(), "ConstrainRefIndex: index requested but not contained");
60 return std::get<mlir::APInt>(index);
61 }
62
63 bool isIndexRange() const { return std::holds_alternative<IndexRange>(index); }
64 IndexRange getIndexRange() const {
65 ensure(isIndexRange(), "ConstrainRefIndex: index range requested but not contained");
66 return std::get<IndexRange>(index);
67 }
68
69 inline void dump() const { print(llvm::errs()); }
70 void print(mlir::raw_ostream &os) const;
71
72 inline bool operator==(const ConstrainRefIndex &rhs) const {
73 if (isField() && rhs.isField()) {
74 // We compare the underlying fields, since the field could be in a symbol
75 // lookup or not.
76 return getField() == rhs.getField();
77 }
78 return index == rhs.index;
79 }
80
81 bool operator<(const ConstrainRefIndex &rhs) const;
82
83 bool operator>(const ConstrainRefIndex &rhs) const { return rhs < *this; }
84
85 struct Hash {
86 size_t operator()(const ConstrainRefIndex &c) const;
87 };
88
89 inline size_t getHash() const { return Hash {}(*this); }
90
91private:
98 std::variant<
100 index;
101};
102
103static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefIndex &rhs) {
104 rhs.print(os);
105 return os;
106}
107
118
// NOTE(review): the definitions of the getAllConstrainRefs overloads are not visible here.
// Each overload below takes a root block argument plus a `fields` path of indices and
// returns a list of ConstrainRefs; presumably they recurse through array elements /
// struct fields to enumerate every reachable reference — confirm against the implementation.
// Overload for a root whose type is an array type (`arrayTy`).
124 static std::vector<ConstrainRef> getAllConstrainRefs(
125 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, array::ArrayType arrayTy,
126 mlir::BlockArgument blockArg, std::vector<ConstrainRefIndex> fields
127 );
128
// Overload for a root whose type resolves to the struct definition `s`.
133 static std::vector<ConstrainRef> getAllConstrainRefs(
134 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod,
135 SymbolLookupResult<component::StructDefOp> s, mlir::BlockArgument blockArg,
136 std::vector<ConstrainRefIndex> fields
137 );
138
// Overload taking only the root argument; `fields` defaults to an empty path.
141 static std::vector<ConstrainRef> getAllConstrainRefs(
142 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, mlir::BlockArgument arg,
143 std::vector<ConstrainRefIndex> fields = {}
144 );
145
146public:
// Presumably returns all ConstrainRefs for the given struct's constrain function — confirm.
148 static std::vector<ConstrainRef> getAllConstrainRefs(component::StructDefOp structDef);
149
// Presumably restricted to the references rooted at `fieldDef` within `structDef` — confirm.
152 static std::vector<ConstrainRef>
153 getAllConstrainRefs(component::StructDefOp structDef, component::FieldDefOp fieldDef);
154
155 explicit ConstrainRef(mlir::BlockArgument b)
156 : blockArg(b), fieldRefs(), constantVal(std::nullopt) {}
157 ConstrainRef(mlir::BlockArgument b, std::vector<ConstrainRefIndex> f)
158 : blockArg(b), fieldRefs(std::move(f)), constantVal(std::nullopt) {}
159 explicit ConstrainRef(felt::FeltConstantOp c) : blockArg(nullptr), fieldRefs(), constantVal(c) {}
160 explicit ConstrainRef(mlir::arith::ConstantIndexOp c)
161 : blockArg(nullptr), fieldRefs(), constantVal(c) {}
163 : blockArg(nullptr), fieldRefs(), constantVal(c) {}
164
165 mlir::Type getType() const;
166
167 bool isConstantFelt() const {
168 return constantVal.has_value() && std::holds_alternative<felt::FeltConstantOp>(*constantVal);
169 }
170 bool isConstantIndex() const {
171 return constantVal.has_value() &&
172 std::holds_alternative<mlir::arith::ConstantIndexOp>(*constantVal);
173 }
174 bool isTemplateConstant() const {
175 return constantVal.has_value() &&
176 std::holds_alternative<polymorphic::ConstReadOp>(*constantVal);
177 }
178 bool isConstant() const { return constantVal.has_value(); }
179
180 bool isFeltVal() const { return mlir::isa<felt::FeltType>(getType()); }
181 bool isIndexVal() const { return mlir::isa<mlir::IndexType>(getType()); }
182 bool isIntegerVal() const { return mlir::isa<mlir::IntegerType>(getType()); }
183 bool isTypeVarVal() const { return mlir::isa<polymorphic::TypeVarType>(getType()); }
184 bool isScalar() const {
185 return isConstant() || isFeltVal() || isIndexVal() || isIntegerVal() || isTypeVarVal();
186 }
187 bool isSignal() const { return isSignalType(getType()); }
188
189 bool isBlockArgument() const { return blockArg != nullptr; }
190 mlir::BlockArgument getBlockArgument() const {
191 ensure(isBlockArgument(), "is not a block argument");
192 return blockArg;
193 }
194 unsigned getInputNum() const { return blockArg.getArgNumber(); }
195
196 mlir::APInt getConstantFeltValue() const {
197 ensure(isConstantFelt(), __FUNCTION__ + mlir::Twine(" requires a constant felt!"));
198 return std::get<felt::FeltConstantOp>(*constantVal).getValueAttr().getValue();
199 }
200 mlir::APInt getConstantIndexValue() const {
201 ensure(isConstantIndex(), __FUNCTION__ + mlir::Twine(" requires a constant index!"));
202 return toAPInt(std::get<mlir::arith::ConstantIndexOp>(*constantVal).value());
203 }
204 mlir::APInt getConstantValue() const {
205 ensure(
207 __FUNCTION__ + mlir::Twine(" requires a constant int type!")
208 );
210 }
211
/// Returns true iff `prefix` is a valid prefix of this reference.
213 bool isValidPrefix(const ConstrainRef &prefix) const;
214
/// If `prefix` is a valid prefix of this reference, returns the index pieces that remain
/// after removing the prefix; otherwise returns failure.
219 mlir::FailureOr<std::vector<ConstrainRefIndex>> getSuffix(const ConstrainRef &prefix) const;
220
/// Creates a new reference in which `prefix` is replaced with `other`, iff `prefix` is a
/// valid prefix of this reference; otherwise returns failure.
227 mlir::FailureOr<ConstrainRef>
228 translate(const ConstrainRef &prefix, const ConstrainRef &other) const;
229
232 mlir::FailureOr<ConstrainRef> getParentPrefix() const {
233 if (isConstantFelt() || fieldRefs.empty()) {
234 return mlir::failure();
235 }
236 auto copy = *this;
237 copy.fieldRefs.pop_back();
238 return copy;
239 }
240
242 auto copy = *this;
243 copy.fieldRefs.push_back(r);
244 return copy;
245 }
246
248 auto copy = *this;
249 assert(other.isConstantIndex());
250 copy.fieldRefs.push_back(ConstrainRefIndex(other.getConstantIndexValue()));
251 return copy;
252 }
253
254 const std::vector<ConstrainRefIndex> &getPieces() const { return fieldRefs; }
255
256 void print(mlir::raw_ostream &os) const;
257 void dump() const { print(llvm::errs()); }
258
259 bool operator==(const ConstrainRef &rhs) const;
260
261 bool operator!=(const ConstrainRef &rhs) const { return !(*this == rhs); }
262
263 // required for EquivalenceClasses usage
264 bool operator<(const ConstrainRef &rhs) const;
265
266 bool operator>(const ConstrainRef &rhs) const { return rhs < *this; }
267
268 struct Hash {
269 size_t operator()(const ConstrainRef &val) const;
270 };
271
272private:
// Root of the reference; null for constant references (the constant constructors pass nullptr).
277 mlir::BlockArgument blockArg;
278
// Index pieces applied to the root; createChild appends and getParentPrefix pops from the back.
279 std::vector<ConstrainRefIndex> fieldRefs;
// Set only for constant references; block-argument references store std::nullopt.
280 // using mutable to reduce constant casts for certain get* functions.
281 mutable std::optional<
282 std::variant<felt::FeltConstantOp, mlir::arith::ConstantIndexOp, polymorphic::ConstReadOp>>
283 constantVal;
284};
285
/// Prints `rhs` to `os`; only declared here, the definition is out of line.
286mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs);
287
288/* ConstrainRefSet */
289
290class ConstrainRefSet : public std::unordered_set<ConstrainRef, ConstrainRef::Hash> {
291 using Base = std::unordered_set<ConstrainRef, ConstrainRef::Hash>;
292
293public:
294 using Base::Base;
295
297
298 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs);
299};
300
301static_assert(
303 "ConstrainRefSet must satisfy the ScalarLatticeValue requirements"
304);
305
306} // namespace llzk
307
308namespace llvm {
309
310template <> struct DenseMapInfo<llzk::ConstrainRef> {
312 return llzk::ConstrainRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
313 }
315 return llzk::ConstrainRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
316 }
317 static unsigned getHashValue(const llzk::ConstrainRef &ref) {
318 if (ref == getEmptyKey() || ref == getTombstoneKey()) {
319 return llvm::hash_value(ref.getBlockArgument().getAsOpaquePointer());
320 }
321 return llzk::ConstrainRef::Hash {}(ref);
322 }
323 static bool isEqual(const llzk::ConstrainRef &lhs, const llzk::ConstrainRef &rhs) {
324 return lhs == rhs;
325 }
326};
327
328} // namespace llvm
Defines an index into an LLZK object.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
ConstrainRefIndex(mlir::APInt i)
size_t getHash() const
bool operator<(const ConstrainRefIndex &rhs) const
bool operator==(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(SymbolLookupResult< component::FieldDefOp > f)
ConstrainRefIndex(component::FieldDefOp f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefIndex(IndexRange r)
ConstrainRefIndex(mlir::APInt low, mlir::APInt high)
bool operator>(const ConstrainRefIndex &rhs) const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool isTypeVarVal() const
const std::vector< ConstrainRefIndex > & getPieces() const
bool operator<(const ConstrainRef &rhs) const
ConstrainRef(felt::FeltConstantOp c)
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
bool isIntegerVal() const
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isConstant() const
ConstrainRef createChild(ConstrainRef other) const
ConstrainRef(polymorphic::ConstReadOp c)
bool isTemplateConstant() const
ConstrainRef(mlir::BlockArgument b, std::vector< ConstrainRefIndex > f)
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
ConstrainRef(mlir::BlockArgument b)
bool operator==(const ConstrainRef &rhs) const
mlir::FailureOr< ConstrainRef > getParentPrefix() const
Create a new reference that is the immediate prefix of this reference if possible.
ConstrainRef(mlir::arith::ConstantIndexOp c)
bool isBlockArgument() const
mlir::BlockArgument getBlockArgument() const
ConstrainRef createChild(ConstrainRefIndex r) const
bool isSignal() const
bool isIndexVal() const
bool isFeltVal() const
bool operator!=(const ConstrainRef &rhs) const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
mlir::Type getType() const
void dump() const
mlir::APInt getConstantValue() const
bool operator>(const ConstrainRef &rhs) const
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
bool isScalar() const
unsigned getInputNum() const
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
llvm::APInt toAPInt(int64_t i)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
bool isSignalType(Type type)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static bool isEqual(const llzk::ConstrainRef &lhs, const llzk::ConstrainRef &rhs)
static llzk::ConstrainRef getTombstoneKey()
static llzk::ConstrainRef getEmptyKey()
static unsigned getHashValue(const llzk::ConstrainRef &ref)
size_t operator()(const ConstrainRefIndex &c) const
size_t operator()(const ConstrainRef &val) const