LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Builders.h
Go to the documentation of this file.
1//===-- Builders.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
13
14#include <mlir/IR/Builders.h>
15#include <mlir/IR/MLIRContext.h>
16
17#include <llvm/ADT/DenseMap.h>
18#include <llvm/ADT/DenseSet.h>
19
20#include <deque>
21#include <unordered_map>
22
23namespace llzk {
24
25inline mlir::Location getUnknownLoc(mlir::MLIRContext *context) {
26 return mlir::UnknownLoc::get(context);
27}
28
29mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
30
31inline mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context) {
32 return createLLZKModule(context, getUnknownLoc(context));
33}
34
35void addLangAttrForLLZKDialect(mlir::ModuleOp mod);
36
42public:
43 ModuleBuilder(mlir::ModuleOp m) : context(m.getContext()), rootModule(m) {}
44
45 /* Builder methods */
46
47 inline mlir::Location getUnknownLoc() { return llzk::getUnknownLoc(context); }
48
50 insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams = -1);
// NOTE(review): the return-type line (`ModuleBuilder &`, per the tooltip) of
// the declaration above is missing from this listing. Per its name, it inserts
// a struct named `structName` at `loc` with no functions. The meaning of
// `numStructParams` (default -1) is defined in Builders.cpp. Presumably -1
// means "no struct parameters"; confirm.
/// Overload of insertEmptyStruct that uses an unknown location and forwards
/// `numStructParams` unchanged.
51 inline ModuleBuilder &insertEmptyStruct(std::string_view structName, int numStructParams = -1) {
52 return insertEmptyStruct(structName, getUnknownLoc(), numStructParams);
53 }
54
56 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
57 ) {
// NOTE(review): the first signature line of this method is missing from this
// listing. The tooltip gives it as
// insertComputeOnlyStruct(structName, structLoc, computeLoc).
// This inserts an empty struct at `structLoc`, then adds only a compute
// function for it at `computeLoc`. No numStructParams argument is forwarded,
// so insertEmptyStruct's default (-1) applies.
58 insertEmptyStruct(structName, structLoc);
59 insertComputeFn(structName, computeLoc);
// Return the builder so calls can be chained.
60 return *this;
61 }
62
63 ModuleBuilder &insertComputeOnlyStruct(std::string_view structName) {
64 auto unk = getUnknownLoc();
65 return insertComputeOnlyStruct(structName, unk, unk);
66 }
67
69 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
70 ) {
// NOTE(review): the first signature line of this method is missing from this
// listing. The tooltip gives it as
// insertConstrainOnlyStruct(structName, structLoc, constrainLoc).
// This inserts an empty struct at `structLoc` (default numStructParams), then
// adds only a constrain function for it. `constrainLoc` is passed to
// insertConstrainFn(structName, loc), so it only takes effect if that overload
// forwards its `loc` argument.
71 insertEmptyStruct(structName, structLoc);
72 insertConstrainFn(structName, constrainLoc);
// Return the builder so calls can be chained.
73 return *this;
74 }
75
76 ModuleBuilder &insertConstrainOnlyStruct(std::string_view structName) {
77 auto unk = getUnknownLoc();
78 return insertConstrainOnlyStruct(structName, unk, unk);
79 }
80
82 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
83 mlir::Location constrainLoc, int numStructParams = -1
84 ) {
// NOTE(review): the first signature line of this method is missing from this
// listing. The tooltip gives it as insertFullStruct(structName, structLoc,
// computeLoc, constrainLoc, numStructParams = -1).
// This inserts the struct with `numStructParams`, then adds both a compute and
// a constrain function at their respective locations.
85 insertEmptyStruct(structName, structLoc, numStructParams);
86 insertComputeFn(structName, computeLoc);
87 insertConstrainFn(structName, constrainLoc);
// Return the builder so calls can be chained.
88 return *this;
89 }
90
92 ModuleBuilder &insertFullStruct(std::string_view structName, int numStructParams = -1) {
93 auto unk = getUnknownLoc();
94 return insertFullStruct(structName, unk, unk, unk, numStructParams);
95 }
96
102 ModuleBuilder &insertComputeFn(std::string_view structName, mlir::Location loc) {
103 return insertComputeFn(*getStruct(structName), loc);
104 }
105 ModuleBuilder &insertComputeFn(std::string_view structName) {
106 return insertComputeFn(structName, getUnknownLoc());
107 }
108
113 ModuleBuilder &insertConstrainFn(std::string_view structName, mlir::Location loc) {
114 return insertConstrainFn(*getStruct(structName), getUnknownLoc());
115 }
116 ModuleBuilder &insertConstrainFn(std::string_view structName) {
117 return insertConstrainFn(structName, getUnknownLoc());
118 }
119
127 mlir::Location callLoc
128 );
// NOTE(review): the head of the declaration above and the return-type line of
// the overload below are missing from this listing. The tooltip gives the
// declaration above as insertComputeCall(StructDefOp caller,
// StructDefOp callee, Location callLoc).
/// Name-based overload. It resolves `caller` and `callee` to their structs and
/// forwards to the StructDefOp overload. Both structs must already exist,
/// because the lookups are dereferenced without checking.
130 insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc) {
131 return insertComputeCall(*getStruct(caller), *getStruct(callee), callLoc);
132 }
/// Same as above, but the call gets an unknown location.
133 ModuleBuilder &insertComputeCall(std::string_view caller, std::string_view callee) {
134 return insertComputeCall(caller, callee, getUnknownLoc());
135 }
136
145 mlir::Location callLoc, mlir::Location fieldDefLoc
146 );
// NOTE(review): the head of the declaration above and the first signature
// line of the overload below are missing from this listing. The tooltip gives
// the declaration above as insertConstrainCall(StructDefOp caller,
// StructDefOp callee, Location callLoc, Location fieldDefLoc).
/// Name-based overload. It resolves `caller` and `callee` to their structs and
/// forwards both locations. Both structs must already exist, because the
/// lookups are dereferenced without checking.
148 std::string_view caller, std::string_view callee, mlir::Location callLoc,
149 mlir::Location fieldDefLoc
150 ) {
151 return insertConstrainCall(*getStruct(caller), *getStruct(callee), callLoc, fieldDefLoc);
152 }
/// Same as above, with unknown locations for both the call and the field
/// definition.
153 ModuleBuilder &insertConstrainCall(std::string_view caller, std::string_view callee) {
154 return insertConstrainCall(caller, callee, getUnknownLoc(), getUnknownLoc());
155 }
156
158 insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
// NOTE(review): the return-type line (`ModuleBuilder &`, per the tooltip) of
// the declaration above is missing from this listing. It is implemented out of
// line and, per its name, inserts a free function `funcName` of `type` at
// `loc`.
/// Same as above, but the function gets an unknown location.
159 inline ModuleBuilder &insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type) {
160 return insertGlobalFunc(funcName, type, getUnknownLoc());
161 }
162
164 insertGlobalCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc);
// NOTE(review): the return-type line (`ModuleBuilder &`, per the tooltip) of
// the declaration above is missing from this listing. It is implemented out of
// line and, per its name, inserts into `caller` a call to the global function
// named `callee` at `callLoc`.
/// Same as above, but the call gets an unknown location.
165 ModuleBuilder &insertGlobalCall(function::FuncDefOp caller, std::string_view callee) {
166 return insertGlobalCall(caller, callee, getUnknownLoc());
167 }
168
169 /* Getter methods */
170
172 mlir::ModuleOp &getRootModule() { return rootModule; }
173
174 mlir::FailureOr<llzk::component::StructDefOp> getStruct(std::string_view structName) const {
175 if (structMap.find(structName) != structMap.end()) {
176 return structMap.at(structName);
177 }
178 return mlir::failure();
179 }
180
181 mlir::FailureOr<function::FuncDefOp> getComputeFn(std::string_view structName) const {
182 if (computeFnMap.find(structName) != computeFnMap.end()) {
183 return computeFnMap.at(structName);
184 }
185 return mlir::failure();
186 }
187 inline mlir::FailureOr<function::FuncDefOp> getComputeFn(llzk::component::StructDefOp op) const {
188 return getComputeFn(op.getName());
189 }
190
191 mlir::FailureOr<function::FuncDefOp> getConstrainFn(std::string_view structName) {
192 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
193 return constrainFnMap.at(structName);
194 }
195 return mlir::failure();
196 }
197 inline mlir::FailureOr<function::FuncDefOp> getConstrainFn(llzk::component::StructDefOp op) {
198 return getConstrainFn(op.getName());
199 }
200
201 mlir::FailureOr<function::FuncDefOp> getGlobalFunc(std::string_view funcName) const {
202 if (globalFuncMap.find(funcName) != globalFuncMap.end()) {
203 return globalFuncMap.at(funcName);
204 }
205 return mlir::failure();
206 }
207
208 /* Helper functions */
209
// NOTE(review): the doc comment and signature of
// computeReachable(StructDefOp caller, StructDefOp callee) are missing from
// this listing; only the tail of its body is shown. It answers reachability
// over the compute call graph (computeNodes).
214 return isReachable(computeNodes, caller, callee);
215 }
/// Name-based overload. It resolves both names (the structs must already be
/// inserted, since the lookups are dereferenced without checking) and forwards
/// to the StructDefOp overload.
216 bool computeReachable(std::string_view caller, std::string_view callee) {
217 return computeReachable(*getStruct(caller), *getStruct(callee));
218 }
219
223 bool
// NOTE(review): the rest of the signature, constrainReachable(StructDefOp
// caller, StructDefOp callee), is missing from this listing. It returns
// whether the callee's constrain function is reachable from the caller's via
// the recorded constrain call graph (constrainNodes). It does not check the
// compute graph.
225 return isReachable(constrainNodes, caller, callee);
226 }
/// Name-based overload. It resolves both names (the structs must already be
/// inserted, since the lookups are dereferenced without checking) and forwards
/// to the StructDefOp overload.
227 bool constrainReachable(std::string_view caller, std::string_view callee) {
228 return constrainReachable(*getStruct(caller), *getStruct(callee));
229 }
230
231private:
// Context of the root module, captured in the constructor.
232 mlir::MLIRContext *context;
// The module this builder populates.
233 mlir::ModuleOp rootModule;
234
// A node in a call graph. It maps each directly called struct to that
// struct's node. isReachable only reads the keys.
// NOTE(review): the CallNode* values point into a DenseMap's value storage,
// which moves when that map grows. They must not be dereferenced after later
// insertions.
235 struct CallNode {
236 mlir::DenseMap<llzk::component::StructDefOp, CallNode *> callees;
237 };
238
239 using Def2NodeMap = mlir::DenseMap<llzk::component::StructDefOp, CallNode>;
240 using StructDefSet = mlir::DenseSet<llzk::component::StructDefOp>;
241
// Separate call graphs for compute and constrain calls. Edges are added via
// updateComputeReachability / updateConstrainReachability.
242 Def2NodeMap computeNodes, constrainNodes;
243
// Name -> op tables backing the get* accessors.
// NOTE(review): keys are std::string_view, so the characters they reference
// must outlive these maps. Confirm in Builders.cpp where the keys come from.
244 std::unordered_map<std::string_view, function::FuncDefOp> globalFuncMap;
245 std::unordered_map<std::string_view, llzk::component::StructDefOp> structMap;
246 std::unordered_map<std::string_view, function::FuncDefOp> computeFnMap;
247 std::unordered_map<std::string_view, function::FuncDefOp> constrainFnMap;
248
// NOTE(review): the doc comments for these validation helpers are missing from
// this listing, and their definitions are out of line (presumably in
// Builders.cpp). Going by their names, the ensureNoSuch* helpers reject a name
// that is already registered, and the *Exists helpers reject one that is not.
// Confirm how failure is reported (assert, diagnostic, abort).
252 void ensureNoSuchGlobalFunc(std::string_view funcName);
253
257 void ensureGlobalFnExists(std::string_view funcName);
258
262 void ensureNoSuchStruct(std::string_view structName);
263
267 void ensureNoSuchComputeFn(std::string_view structName);
268
272 void ensureComputeFnExists(std::string_view structName);
273
277 void ensureNoSuchConstrainFn(std::string_view structName);
278
282 void ensureConstrainFnExists(std::string_view structName);
283
// Records a caller -> callee edge in the compute call graph (computeNodes).
// NOTE(review): the parameter line is missing from this listing. By analogy
// with updateConstrainReachability, it presumably is
// (StructDefOp caller, StructDefOp callee).
284 void updateComputeReachability(
286 ) {
287 updateReachability(computeNodes, caller, callee);
288 }
289
290 void updateConstrainReachability(
291 llzk::component::StructDefOp caller, llzk::component::StructDefOp callee
292 ) {
293 updateReachability(constrainNodes, caller, callee);
294 }
295
296 void updateReachability(
297 Def2NodeMap &m, llzk::component::StructDefOp caller, llzk::component::StructDefOp callee
298 ) {
299 auto &callerNode = m[caller];
300 auto &calleeNode = m[callee];
301 callerNode.callees[callee] = &calleeNode;
302 }
303
304 bool isReachable(
305 Def2NodeMap &m, llzk::component::StructDefOp caller, llzk::component::StructDefOp callee
306 ) {
307 StructDefSet visited;
308 std::deque<llzk::component::StructDefOp> frontier;
309 frontier.push_back(caller);
310
311 while (!frontier.empty()) {
312 auto s = frontier.front();
313 frontier.pop_front();
314 if (!visited.insert(s).second) {
315 continue;
316 }
317
318 if (s == callee) {
319 return true;
320 }
321 for (auto &[calleeStruct, _] : m[s].callees) {
322 frontier.push_back(calleeStruct);
323 }
324 }
325 return false;
326 }
327};
328
329} // namespace llzk
Builds out a LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:41
ModuleBuilder & insertEmptyStruct(std::string_view structName, int numStructParams=-1)
Definition Builders.h:51
ModuleBuilder(mlir::ModuleOp m)
Definition Builders.h:43
bool constrainReachable(std::string_view caller, std::string_view callee)
Definition Builders.h:227
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName)
Definition Builders.h:191
bool computeReachable(std::string_view caller, std::string_view callee)
Definition Builders.h:216
ModuleBuilder & insertConstrainFn(std::string_view structName, mlir::Location loc)
Definition Builders.h:113
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee)
Definition Builders.h:153
ModuleBuilder & insertGlobalCall(function::FuncDefOp caller, std::string_view callee)
Definition Builders.h:165
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee)
Definition Builders.h:133
mlir::FailureOr< function::FuncDefOp > getComputeFn(llzk::component::StructDefOp op) const
Definition Builders.h:187
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName)
Definition Builders.h:63
ModuleBuilder & insertConstrainFn(std::string_view structName)
Definition Builders.h:116
ModuleBuilder & insertComputeFn(std::string_view structName)
Definition Builders.h:105
ModuleBuilder & insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type)
Definition Builders.h:159
mlir::ModuleOp & getRootModule()
Get the top-level LLZK module.
Definition Builders.h:172
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
Definition Builders.h:147
ModuleBuilder & insertComputeFn(std::string_view structName, mlir::Location loc)
Definition Builders.h:102
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
Definition Builders.h:181
mlir::Location getUnknownLoc()
Definition Builders.h:47
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
Definition Builders.h:130
mlir::FailureOr< function::FuncDefOp > getGlobalFunc(std::string_view funcName) const
Definition Builders.h:201
ModuleBuilder & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc, int numStructParams=-1)
Definition Builders.h:81
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
Definition Builders.h:68
mlir::FailureOr< function::FuncDefOp > getConstrainFn(llzk::component::StructDefOp op)
Definition Builders.h:197
ModuleBuilder & insertConstrainFn(llzk::component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
ModuleBuilder & insertFullStruct(std::string_view structName, int numStructParams=-1)
Inserts a struct with both compute and constrain functions.
Definition Builders.h:92
ModuleBuilder & insertGlobalCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
ModuleBuilder & insertComputeFn(llzk::component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
bool constrainReachable(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee)
Returns whether the callee's constrain function is reachable from the caller's, by construction.
Definition Builders.h:224
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
Definition Builders.h:55
mlir::FailureOr< llzk::component::StructDefOp > getStruct(std::string_view structName) const
Definition Builders.h:174
ModuleBuilder & insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
ModuleBuilder & insertComputeCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName)
Definition Builders.h:76
ModuleBuilder & insertConstrainCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
To call a constraint function, you must:
bool computeReachable(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee)
Returns whether the callee's compute function is reachable from the caller's, by construction.
Definition Builders.h:213
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
Definition Builders.cpp:28
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)
Definition Builders.cpp:22
mlir::Location getUnknownLoc(mlir::MLIRContext *context)
Definition Builders.h:25