LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ConstrainRef.h
Go to the documentation of this file.
1//===-- ConstrainRef.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
18#include "llzk/Util/Hash.h"
19
20#include <mlir/Analysis/DataFlowFramework.h>
21#include <mlir/Dialect/Arith/IR/Arith.h>
22#include <mlir/Pass/AnalysisManager.h>
23
24#include <llvm/ADT/EquivalenceClasses.h>
25
26#include <unordered_set>
27#include <vector>
28
29namespace llzk {
30
35 using IndexRange = std::pair<mlir::APInt, mlir::APInt>;
36
37public:
39 explicit ConstrainRefIndex(mlir::APInt i) : index(i) {}
40 explicit ConstrainRefIndex(int64_t i) : index(toAPInt(i)) {}
41 ConstrainRefIndex(mlir::APInt low, mlir::APInt high) : index(IndexRange {low, high}) {}
42 explicit ConstrainRefIndex(IndexRange r) : index(r) {}
43
44 bool isField() const {
45 return std::holds_alternative<SymbolLookupResult<component::FieldDefOp>>(index);
46 }
48 ensure(isField(), "ConstrainRefIndex: field requested but not contained");
49 return std::get<SymbolLookupResult<component::FieldDefOp>>(index).get();
50 }
51
52 bool isIndex() const { return std::holds_alternative<mlir::APInt>(index); }
53 mlir::APInt getIndex() const {
54 ensure(isIndex(), "ConstrainRefIndex: index requested but not contained");
55 return std::get<mlir::APInt>(index);
56 }
57
58 bool isIndexRange() const { return std::holds_alternative<IndexRange>(index); }
59 IndexRange getIndexRange() const {
60 ensure(isIndexRange(), "ConstrainRefIndex: index range requested but not contained");
61 return std::get<IndexRange>(index);
62 }
63
64 inline void dump() const { print(llvm::errs()); }
65 void print(mlir::raw_ostream &os) const;
66
67 inline bool operator==(const ConstrainRefIndex &rhs) const { return index == rhs.index; }
68
69 bool operator<(const ConstrainRefIndex &rhs) const;
70
71 bool operator>(const ConstrainRefIndex &rhs) const { return rhs < *this; }
72
73 struct Hash {
74 size_t operator()(const ConstrainRefIndex &c) const {
75 if (c.isIndex()) {
76 return llvm::hash_value(c.getIndex());
77 } else if (c.isIndexRange()) {
78 auto r = c.getIndexRange();
79 return llvm::hash_value(std::get<0>(r)) ^ llvm::hash_value(std::get<1>(r));
80 } else {
82 }
83 }
84 };
85
86 size_t getHash() const { return Hash {}(*this); }
87
88private:
// Payload of this index: exactly one of a field lookup result, a single
// integer index, or a (low, high) index range. Queried via isField(),
// isIndex() and isIndexRange().
94 std::variant<SymbolLookupResult<component::FieldDefOp>, mlir::APInt, IndexRange> index;
95};
96
97static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefIndex &rhs) {
98 rhs.print(os);
99 return os;
100}
101
112
// Overload for array-typed values: enumerates references rooted at `blockArg`
// for a value of type `arrayTy`. `fields` appears to be the index path
// accumulated so far.
// NOTE(review): presumably recurses into each array element; confirm in the .cpp.
118 static std::vector<ConstrainRef> getAllConstrainRefs(
119 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, array::ArrayType arrayTy,
120 mlir::BlockArgument blockArg, std::vector<ConstrainRefIndex> fields
121 );
122
// Overload for struct-typed values: `s` is the resolved struct definition,
// `blockArg` the root argument, and `fields` appears to be the index path so far.
// NOTE(review): presumably recurses into the struct's fields; confirm in the .cpp.
127 static std::vector<ConstrainRef> getAllConstrainRefs(
128 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod,
129 SymbolLookupResult<component::StructDefOp> s, mlir::BlockArgument blockArg,
130 std::vector<ConstrainRefIndex> fields
131 );
132
// Entry overload for a single block argument `arg`.
// NOTE(review): presumably dispatches on the argument's type to the overloads
// above; confirm in the .cpp.
134 static std::vector<ConstrainRef> getAllConstrainRefs(
135 mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, mlir::BlockArgument arg
136 );
137
138public:
// Enumerates all ConstrainRefs for the given struct definition.
// NOTE(review): presumably rooted at the arguments of the struct's constrain
// function; confirm in the .cpp.
140 static std::vector<ConstrainRef> getAllConstrainRefs(component::StructDefOp structDef);
141
142 explicit ConstrainRef(mlir::BlockArgument b)
143 : blockArg(b), fieldRefs(), constantVal(std::nullopt) {}
144 ConstrainRef(mlir::BlockArgument b, std::vector<ConstrainRefIndex> f)
145 : blockArg(b), fieldRefs(std::move(f)), constantVal(std::nullopt) {}
146 explicit ConstrainRef(felt::FeltConstantOp c) : blockArg(nullptr), fieldRefs(), constantVal(c) {}
147 explicit ConstrainRef(mlir::arith::ConstantIndexOp c)
148 : blockArg(nullptr), fieldRefs(), constantVal(c) {}
150 : blockArg(nullptr), fieldRefs(), constantVal(c) {}
151
152 mlir::Type getType() const;
153
154 bool isConstantFelt() const {
155 return constantVal.has_value() && std::holds_alternative<felt::FeltConstantOp>(*constantVal);
156 }
157 bool isConstantIndex() const {
158 return constantVal.has_value() &&
159 std::holds_alternative<mlir::arith::ConstantIndexOp>(*constantVal);
160 }
161 bool isTemplateConstant() const {
162 return constantVal.has_value() &&
163 std::holds_alternative<polymorphic::ConstReadOp>(*constantVal);
164 }
165 bool isConstant() const { return constantVal.has_value(); }
166
167 bool isFeltVal() const { return mlir::isa<felt::FeltType>(getType()); }
168 bool isIndexVal() const { return mlir::isa<mlir::IndexType>(getType()); }
169 bool isIntegerVal() const { return mlir::isa<mlir::IntegerType>(getType()); }
170 bool isTypeVarVal() const { return mlir::isa<polymorphic::TypeVarType>(getType()); }
171 bool isScalar() const {
172 return isConstant() || isFeltVal() || isIndexVal() || isIntegerVal() || isTypeVarVal();
173 }
174 bool isSignal() const { return isSignalType(getType()); }
175
176 bool isBlockArgument() const { return blockArg != nullptr; }
177 mlir::BlockArgument getBlockArgument() const {
178 ensure(isBlockArgument(), "is not a block argument");
179 return blockArg;
180 }
181 unsigned getInputNum() const { return blockArg.getArgNumber(); }
182
183 mlir::APInt getConstantFeltValue() const {
184 ensure(isConstantFelt(), __FUNCTION__ + mlir::Twine(" requires a constant felt!"));
185 return std::get<felt::FeltConstantOp>(*constantVal).getValueAttr().getValue();
186 }
187 mlir::APInt getConstantIndexValue() const {
188 ensure(isConstantIndex(), __FUNCTION__ + mlir::Twine(" requires a constant index!"));
189 return toAPInt(std::get<mlir::arith::ConstantIndexOp>(*constantVal).value());
190 }
191 mlir::APInt getConstantInt() const {
192 ensure(
194 __FUNCTION__ + mlir::Twine(" requires a constant int type!")
195 );
197 }
198
// Returns true iff `prefix` is a valid prefix of this reference.
200 bool isValidPrefix(const ConstrainRef &prefix) const;
201
// If `prefix` is a valid prefix of this reference, returns the index pieces
// that remain after removing the prefix.
// NOTE(review): presumably returns failure() otherwise (per the FailureOr
// return type); confirm in the .cpp.
206 mlir::FailureOr<std::vector<ConstrainRefIndex>> getSuffix(const ConstrainRef &prefix) const;
207
// Creates a new reference with `prefix` replaced by `other`, iff `prefix` is a
// valid prefix of this reference.
// NOTE(review): presumably returns failure() otherwise; confirm in the .cpp.
214 mlir::FailureOr<ConstrainRef>
215 translate(const ConstrainRef &prefix, const ConstrainRef &other) const;
216
219 mlir::FailureOr<ConstrainRef> getParentPrefix() const {
220 if (isConstantFelt() || fieldRefs.empty()) {
221 return mlir::failure();
222 }
223 auto copy = *this;
224 copy.fieldRefs.pop_back();
225 return copy;
226 }
227
229 auto copy = *this;
230 copy.fieldRefs.push_back(r);
231 return copy;
232 }
233
235 auto copy = *this;
236 assert(other.isConstantIndex());
237 copy.fieldRefs.push_back(ConstrainRefIndex(other.getConstantIndexValue()));
238 return copy;
239 }
240
241 const std::vector<ConstrainRefIndex> &getPieces() const { return fieldRefs; }
242
243 void print(mlir::raw_ostream &os) const;
244 void dump() const { print(llvm::errs()); }
245
246 bool operator==(const ConstrainRef &rhs) const;
247
248 bool operator!=(const ConstrainRef &rhs) const { return !(*this == rhs); }
249
250 // required for EquivalenceClasses usage
251 bool operator<(const ConstrainRef &rhs) const;
252
253 bool operator>(const ConstrainRef &rhs) const { return rhs < *this; }
254
255 struct Hash {
256 size_t operator()(const ConstrainRef &val) const;
257 };
258
259private:
// Root block argument; null when this reference wraps a constant.
264 mlir::BlockArgument blockArg;
265
// Index path applied to the root. New pieces are appended at the back, and
// getParentPrefix() drops the last one.
266 std::vector<ConstrainRefIndex> fieldRefs;
// Engaged iff this reference wraps a constant op (felt constant, constant
// index, or template-constant read); empty for block-argument references.
267 // using mutable to reduce constant casts for certain get* functions.
268 mutable std::optional<
269 std::variant<felt::FeltConstantOp, mlir::arith::ConstantIndexOp, polymorphic::ConstReadOp>>
270 constantVal;
271};
272
// Stream a ConstrainRef to `os` and return `os` (defined out of line;
// presumably delegates to ConstrainRef::print — confirm in the .cpp).
273mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs);
274
275/* ConstrainRefSet */
276
277class ConstrainRefSet : public std::unordered_set<ConstrainRef, ConstrainRef::Hash> {
278 using Base = std::unordered_set<ConstrainRef, ConstrainRef::Hash>;
279
280public:
281 using Base::Base;
282
284
285 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs);
286};
287
288static_assert(
290 "ConstrainRefSet must satisfy the ScalarLatticeValue requirements"
291);
292
293} // namespace llzk
294
295namespace llvm {
296
297template <> struct DenseMapInfo<llzk::ConstrainRef> {
299 return llzk::ConstrainRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
300 }
302 return llzk::ConstrainRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
303 }
304 static unsigned getHashValue(const llzk::ConstrainRef &ref) {
305 if (ref == getEmptyKey() || ref == getTombstoneKey()) {
306 return llvm::hash_value(ref.getBlockArgument().getAsOpaquePointer());
307 }
308 return llzk::ConstrainRef::Hash {}(ref);
309 }
310 static bool isEqual(const llzk::ConstrainRef &lhs, const llzk::ConstrainRef &rhs) {
311 return lhs == rhs;
312 }
313};
314
315} // namespace llvm
Defines an index into an LLZK object.
component::FieldDefOp getField() const
void print(mlir::raw_ostream &os) const
bool isIndexRange() const
ConstrainRefIndex(mlir::APInt i)
size_t getHash() const
bool operator<(const ConstrainRefIndex &rhs) const
bool operator==(const ConstrainRefIndex &rhs) const
ConstrainRefIndex(SymbolLookupResult< component::FieldDefOp > f)
IndexRange getIndexRange() const
mlir::APInt getIndex() const
ConstrainRefIndex(IndexRange r)
ConstrainRefIndex(mlir::APInt low, mlir::APInt high)
bool operator>(const ConstrainRefIndex &rhs) const
ConstrainRefSet & join(const ConstrainRefSet &rhs)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRefSet &rhs)
Defines a reference to a llzk object within a constrain function call.
mlir::APInt getConstantFeltValue() const
bool isTypeVarVal() const
const std::vector< ConstrainRefIndex > & getPieces() const
bool operator<(const ConstrainRef &rhs) const
ConstrainRef(felt::FeltConstantOp c)
bool isValidPrefix(const ConstrainRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
bool isIntegerVal() const
void print(mlir::raw_ostream &os) const
bool isConstantFelt() const
bool isConstant() const
ConstrainRef createChild(ConstrainRef other) const
ConstrainRef(polymorphic::ConstReadOp c)
bool isTemplateConstant() const
ConstrainRef(mlir::BlockArgument b, std::vector< ConstrainRefIndex > f)
mlir::APInt getConstantIndexValue() const
bool isConstantIndex() const
ConstrainRef(mlir::BlockArgument b)
mlir::APInt getConstantInt() const
bool operator==(const ConstrainRef &rhs) const
mlir::FailureOr< ConstrainRef > getParentPrefix() const
Create a new reference that is the immediate prefix of this reference if possible.
ConstrainRef(mlir::arith::ConstantIndexOp c)
bool isBlockArgument() const
mlir::BlockArgument getBlockArgument() const
ConstrainRef createChild(ConstrainRefIndex r) const
bool isSignal() const
bool isIndexVal() const
bool isFeltVal() const
bool operator!=(const ConstrainRef &rhs) const
mlir::FailureOr< ConstrainRef > translate(const ConstrainRef &prefix, const ConstrainRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
mlir::Type getType() const
void dump() const
bool operator>(const ConstrainRef &rhs) const
mlir::FailureOr< std::vector< ConstrainRefIndex > > getSuffix(const ConstrainRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
bool isScalar() const
unsigned getInputNum() const
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
llvm::APInt toAPInt(int64_t i)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
bool isSignalType(Type type)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static bool isEqual(const llzk::ConstrainRef &lhs, const llzk::ConstrainRef &rhs)
static llzk::ConstrainRef getTombstoneKey()
static llzk::ConstrainRef getEmptyKey()
static unsigned getHashValue(const llzk::ConstrainRef &ref)
size_t operator()(const ConstrainRefIndex &c) const
size_t operator()(const ConstrainRef &val) const