LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SourceRef.h
Go to the documentation of this file.
1//===-- SourceRef.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
19#include "llzk/Util/Hash.h"
20
21#include <mlir/Analysis/DataFlowFramework.h>
22#include <mlir/Dialect/Arith/IR/Arith.h>
23#include <mlir/Pass/AnalysisManager.h>
24
25#include <llvm/ADT/DynamicAPInt.h>
26#include <llvm/ADT/EquivalenceClasses.h>
27
28#include <unordered_set>
29#include <vector>
30
31namespace llzk {
32
37 using IndexRange = std::pair<llvm::DynamicAPInt, llvm::DynamicAPInt>;
38
39public:
40 explicit SourceRefIndex(component::FieldDefOp f) : index(f) {}
42 explicit SourceRefIndex(const llvm::DynamicAPInt &i) : index(i) {}
43 explicit SourceRefIndex(const llvm::APInt &i) : index(toDynamicAPInt(i)) {}
44 explicit SourceRefIndex(int64_t i) : index(llvm::DynamicAPInt(i)) {}
45 SourceRefIndex(const llvm::APInt &low, const llvm::APInt &high)
46 : index(IndexRange {toDynamicAPInt(low), toDynamicAPInt(high)}) {}
47 explicit SourceRefIndex(IndexRange r) : index(r) {}
48
49 bool isField() const {
50 return std::holds_alternative<SymbolLookupResult<component::FieldDefOp>>(index) ||
51 std::holds_alternative<component::FieldDefOp>(index);
52 }
54 ensure(isField(), "SourceRefIndex: field requested but not contained");
55 if (std::holds_alternative<component::FieldDefOp>(index)) {
56 return std::get<component::FieldDefOp>(index);
57 }
58 return std::get<SymbolLookupResult<component::FieldDefOp>>(index).get();
59 }
60
61 bool isIndex() const { return std::holds_alternative<llvm::DynamicAPInt>(index); }
62 llvm::DynamicAPInt getIndex() const {
63 ensure(isIndex(), "SourceRefIndex: index requested but not contained");
64 return std::get<llvm::DynamicAPInt>(index);
65 }
66
67 bool isIndexRange() const { return std::holds_alternative<IndexRange>(index); }
68 IndexRange getIndexRange() const {
69 ensure(isIndexRange(), "SourceRefIndex: index range requested but not contained");
70 return std::get<IndexRange>(index);
71 }
72
73 inline void dump() const { print(llvm::errs()); }
74 void print(mlir::raw_ostream &os) const;
75
76 inline bool operator==(const SourceRefIndex &rhs) const {
77 if (isField() && rhs.isField()) {
78 // We compare the underlying fields, since the field could be in a symbol
79 // lookup or not.
80 return getField() == rhs.getField();
81 }
82 if (isIndex() && rhs.isIndex()) {
83 return getIndex() == rhs.getIndex();
84 }
85 return index == rhs.index;
86 }
87
88 bool operator<(const SourceRefIndex &rhs) const;
89
90 bool operator>(const SourceRefIndex &rhs) const { return rhs < *this; }
91
92 struct Hash {
93 size_t operator()(const SourceRefIndex &c) const;
94 };
95
96 inline size_t getHash() const { return Hash {}(*this); }
97
98private:
105 std::variant<
107 IndexRange>
108 index;
109};
110
111static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefIndex &rhs) {
112 rhs.print(os);
113 return os;
114}
115
128
129public:
131 static std::vector<SourceRef>
132 getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root);
133
135 static std::vector<SourceRef>
137
140 static std::vector<SourceRef>
142
143 explicit SourceRef(mlir::BlockArgument b) : root(b), fieldRefs(), constantVal(std::nullopt) {}
144 SourceRef(mlir::BlockArgument b, std::vector<SourceRefIndex> f)
145 : root(b), fieldRefs(std::move(f)), constantVal(std::nullopt) {}
146
148 : root(createOp), fieldRefs(), constantVal(std::nullopt) {}
149 SourceRef(component::CreateStructOp createOp, std::vector<SourceRefIndex> f)
150 : root(createOp), fieldRefs(std::move(f)), constantVal(std::nullopt) {}
151
152 explicit SourceRef(felt::FeltConstantOp c) : root(std::nullopt), fieldRefs(), constantVal(c) {}
153 explicit SourceRef(mlir::arith::ConstantIndexOp c)
154 : root(std::nullopt), fieldRefs(), constantVal(c) {}
156 : root(std::nullopt), fieldRefs(), constantVal(c) {}
157
158 mlir::Type getType() const;
159
160 bool isConstantFelt() const {
161 return constantVal.has_value() && std::holds_alternative<felt::FeltConstantOp>(*constantVal);
162 }
163 bool isConstantIndex() const {
164 return constantVal.has_value() &&
165 std::holds_alternative<mlir::arith::ConstantIndexOp>(*constantVal);
166 }
167 bool isTemplateConstant() const {
168 return constantVal.has_value() &&
169 std::holds_alternative<polymorphic::ConstReadOp>(*constantVal);
170 }
171 bool isConstant() const { return constantVal.has_value(); }
172 bool isConstantInt() const { return isConstantFelt() || isConstantIndex(); }
173
174 bool isFeltVal() const { return llvm::isa<felt::FeltType>(getType()); }
175 bool isIndexVal() const { return llvm::isa<mlir::IndexType>(getType()); }
176 bool isIntegerVal() const { return llvm::isa<mlir::IntegerType>(getType()); }
177 bool isTypeVarVal() const { return llvm::isa<polymorphic::TypeVarType>(getType()); }
178 bool isScalar() const {
179 return isConstant() || isFeltVal() || isIndexVal() || isIntegerVal() || isTypeVarVal();
180 }
181 bool isSignal() const { return isSignalType(getType()); }
182
183 bool isBlockArgument() const {
184 return root.has_value() && std::holds_alternative<mlir::BlockArgument>(*root);
185 }
186 mlir::BlockArgument getBlockArgument() const {
187 ensure(isBlockArgument(), "is not a block argument");
188 return std::get<mlir::BlockArgument>(*root);
189 }
190 unsigned getInputNum() const { return getBlockArgument().getArgNumber(); }
191
192 bool isCreateStructOp() const {
193 return root.has_value() && std::holds_alternative<component::CreateStructOp>(*root);
194 }
196 ensure(isCreateStructOp(), "is not a create struct op");
197 return std::get<component::CreateStructOp>(*root);
198 }
199
200 llvm::DynamicAPInt getConstantFeltValue() const {
201 ensure(
202 isConstantFelt(), mlir::Twine(mlir::StringRef(__FUNCTION__), " requires a constant felt!")
203 );
204 llvm::APInt i = std::get<felt::FeltConstantOp>(*constantVal).getValue();
205 return toDynamicAPInt(i);
206 }
207 llvm::DynamicAPInt getConstantIndexValue() const {
208 ensure(
209 isConstantIndex(), mlir::Twine(mlir::StringRef(__FUNCTION__), " requires a constant index!")
210 );
211 return llvm::DynamicAPInt(std::get<mlir::arith::ConstantIndexOp>(*constantVal).value());
212 }
213 llvm::DynamicAPInt getConstantValue() const {
214 ensure(
216 mlir::Twine(mlir::StringRef(__FUNCTION__), " requires a constant int type!")
217 );
219 }
220
222 bool isValidPrefix(const SourceRef &prefix) const;
223
228 mlir::FailureOr<std::vector<SourceRefIndex>> getSuffix(const SourceRef &prefix) const;
229
236 mlir::FailureOr<SourceRef> translate(const SourceRef &prefix, const SourceRef &other) const;
237
239 mlir::FailureOr<SourceRef> getParentPrefix() const {
240 if (isConstantFelt() || fieldRefs.empty()) {
241 return mlir::failure();
242 }
243 auto copy = *this;
244 copy.fieldRefs.pop_back();
245 return copy;
246 }
247
249 std::vector<SourceRef>
250 getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const;
251
253 auto copy = *this;
254 copy.fieldRefs.push_back(r);
255 return copy;
256 }
257
259 assert(other.isConstantIndex());
261 }
262
263 const std::vector<SourceRefIndex> &getPieces() const { return fieldRefs; }
264
265 void print(mlir::raw_ostream &os) const;
266 void dump() const { print(llvm::errs()); }
267
268 bool operator==(const SourceRef &rhs) const;
269
270 bool operator!=(const SourceRef &rhs) const { return !(*this == rhs); }
271
272 // required for EquivalenceClasses usage
273 bool operator<(const SourceRef &rhs) const;
274
275 bool operator>(const SourceRef &rhs) const { return rhs < *this; }
276
277 struct Hash {
278 size_t operator()(const SourceRef &val) const;
279 };
280
281private:
291 std::optional<std::variant<mlir::BlockArgument, component::CreateStructOp>> root;
292
293 std::vector<SourceRefIndex> fieldRefs;
294 // using mutable to reduce constant casts for certain get* functions.
295 mutable std::optional<
296 std::variant<felt::FeltConstantOp, mlir::arith::ConstantIndexOp, polymorphic::ConstReadOp>>
297 constantVal;
298};
299
300mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRef &rhs);
301
302/* SourceRefSet */
303
304class SourceRefSet : public std::unordered_set<SourceRef, SourceRef::Hash> {
305 using Base = std::unordered_set<SourceRef, SourceRef::Hash>;
306
307public:
308 using Base::Base;
309
310 SourceRefSet &join(const SourceRefSet &rhs);
311
312 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefSet &rhs);
313};
314
315static_assert(
317 "SourceRefSet must satisfy the ScalarLatticeValue requirements"
318);
319
320} // namespace llzk
321
322namespace llvm {
323
324template <> struct DenseMapInfo<llzk::SourceRef> {
326 return llzk::SourceRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
327 }
329 return llzk::SourceRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
330 }
331 static unsigned getHashValue(const llzk::SourceRef &ref) {
332 if (ref == getEmptyKey() || ref == getTombstoneKey()) {
333 return llvm::hash_value(ref.getBlockArgument().getAsOpaquePointer());
334 }
335 return llzk::SourceRef::Hash {}(ref);
336 }
337 static bool isEqual(const llzk::SourceRef &lhs, const llzk::SourceRef &rhs) { return lhs == rhs; }
338};
339
340} // namespace llvm
This file implements helper methods for constructing DynamicAPInts.
void print(llvm::raw_ostream &os) const
Defines an index into an LLZK object.
Definition SourceRef.h:36
bool operator==(const SourceRefIndex &rhs) const
Definition SourceRef.h:76
bool isIndexRange() const
Definition SourceRef.h:67
bool operator<(const SourceRefIndex &rhs) const
Definition SourceRef.cpp:49
bool operator>(const SourceRefIndex &rhs) const
Definition SourceRef.h:90
size_t getHash() const
Definition SourceRef.h:96
component::FieldDefOp getField() const
Definition SourceRef.h:53
bool isIndex() const
Definition SourceRef.h:61
SourceRefIndex(component::FieldDefOp f)
Definition SourceRef.h:40
SourceRefIndex(const llvm::DynamicAPInt &i)
Definition SourceRef.h:42
SourceRefIndex(SymbolLookupResult< component::FieldDefOp > f)
Definition SourceRef.h:41
SourceRefIndex(const llvm::APInt &low, const llvm::APInt &high)
Definition SourceRef.h:45
SourceRefIndex(const llvm::APInt &i)
Definition SourceRef.h:43
void dump() const
Definition SourceRef.h:73
llvm::DynamicAPInt getIndex() const
Definition SourceRef.h:62
void print(mlir::raw_ostream &os) const
Definition SourceRef.cpp:34
IndexRange getIndexRange() const
Definition SourceRef.h:68
SourceRefIndex(IndexRange r)
Definition SourceRef.h:47
SourceRefIndex(int64_t i)
Definition SourceRef.h:44
bool isField() const
Definition SourceRef.h:49
SourceRefSet & join(const SourceRefSet &rhs)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const SourceRefSet &rhs)
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
SourceRef(component::CreateStructOp createOp, std::vector< SourceRefIndex > f)
Definition SourceRef.h:149
bool isIntegerVal() const
Definition SourceRef.h:176
bool isBlockArgument() const
Definition SourceRef.h:183
llvm::DynamicAPInt getConstantIndexValue() const
Definition SourceRef.h:207
std::vector< SourceRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this SourceRef, assuming this ref is not a scalar.
SourceRef createChild(SourceRef other) const
Definition SourceRef.h:258
mlir::FailureOr< SourceRef > getParentPrefix() const
Create a new reference that is the immediate prefix of this reference if possible.
Definition SourceRef.h:239
void print(mlir::raw_ostream &os) const
mlir::FailureOr< std::vector< SourceRefIndex > > getSuffix(const SourceRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the pref...
bool isScalar() const
Definition SourceRef.h:178
bool operator==(const SourceRef &rhs) const
bool isConstantFelt() const
Definition SourceRef.h:160
component::CreateStructOp getCreateStructOp() const
Definition SourceRef.h:195
SourceRef(felt::FeltConstantOp c)
Definition SourceRef.h:152
bool isValidPrefix(const SourceRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
const std::vector< SourceRefIndex > & getPieces() const
Definition SourceRef.h:263
bool isConstantIndex() const
Definition SourceRef.h:163
SourceRef createChild(SourceRefIndex r) const
Definition SourceRef.h:252
mlir::BlockArgument getBlockArgument() const
Definition SourceRef.h:186
llvm::DynamicAPInt getConstantValue() const
Definition SourceRef.h:213
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void dump() const
Definition SourceRef.h:266
SourceRef(mlir::arith::ConstantIndexOp c)
Definition SourceRef.h:153
bool isIndexVal() const
Definition SourceRef.h:175
SourceRef(mlir::BlockArgument b, std::vector< SourceRefIndex > f)
Definition SourceRef.h:144
bool operator<(const SourceRef &rhs) const
SourceRef(mlir::BlockArgument b)
Definition SourceRef.h:143
mlir::FailureOr< SourceRef > translate(const SourceRef &prefix, const SourceRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this referenc...
SourceRef(polymorphic::ConstReadOp c)
Definition SourceRef.h:155
bool isTemplateConstant() const
Definition SourceRef.h:167
bool isTypeVarVal() const
Definition SourceRef.h:177
bool isConstant() const
Definition SourceRef.h:171
bool isSignal() const
Definition SourceRef.h:181
bool operator!=(const SourceRef &rhs) const
Definition SourceRef.h:270
SourceRef(component::CreateStructOp createOp)
Definition SourceRef.h:147
llvm::DynamicAPInt getConstantFeltValue() const
Definition SourceRef.h:200
bool operator>(const SourceRef &rhs) const
Definition SourceRef.h:275
bool isFeltVal() const
Definition SourceRef.h:174
unsigned getInputNum() const
Definition SourceRef.h:190
bool isConstantInt() const
Definition SourceRef.h:172
bool isCreateStructOp() const
Definition SourceRef.h:192
mlir::Type getType() const
void ensure(bool condition, const llvm::Twine &errMsg)
DynamicAPInt toDynamicAPInt(StringRef str)
Interval operator<<(const Interval &lhs, const Interval &rhs)
bool isSignalType(Type type)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static bool isEqual(const llzk::SourceRef &lhs, const llzk::SourceRef &rhs)
Definition SourceRef.h:337
static unsigned getHashValue(const llzk::SourceRef &ref)
Definition SourceRef.h:331
static llzk::SourceRef getTombstoneKey()
Definition SourceRef.h:328
static llzk::SourceRef getEmptyKey()
Definition SourceRef.h:325
size_t operator()(const SourceRefIndex &c) const
Definition SourceRef.cpp:72
size_t operator()(const SourceRef &val) const