LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Builders.h
Go to the documentation of this file.
1//===-- Builders.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
13
14#include <mlir/IR/Builders.h>
15#include <mlir/IR/MLIRContext.h>
16
17#include <llvm/ADT/DenseMap.h>
18#include <llvm/ADT/DenseSet.h>
19
20#include <deque>
21#include <unordered_map>
22
23namespace llzk {
24
25inline mlir::Location getUnknownLoc(mlir::MLIRContext *context) {
26 return mlir::UnknownLoc::get(context);
27}
28
29mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
30
31inline mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context) {
32 return createLLZKModule(context, getUnknownLoc(context));
33}
34
35void addLangAttrForLLZKDialect(mlir::ModuleOp mod);
36
42public:
43 ModuleBuilder(mlir::ModuleOp m) : context(m.getContext()), rootModule(m) {}
44
45 /* Builder methods */
46
47 inline mlir::Location getUnknownLoc() { return llzk::getUnknownLoc(context); }
48
50 insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams = -1);
51 inline ModuleBuilder &insertEmptyStruct(std::string_view structName, int numStructParams = -1) {
52 return insertEmptyStruct(structName, getUnknownLoc(), numStructParams);
53 }
54
56 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
57 ) {
58 insertEmptyStruct(structName, structLoc);
59 insertComputeFn(structName, computeLoc);
60 return *this;
61 }
62
63 ModuleBuilder &insertComputeOnlyStruct(std::string_view structName) {
64 auto unk = getUnknownLoc();
65 return insertComputeOnlyStruct(structName, unk, unk);
66 }
67
69 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
70 ) {
71 insertEmptyStruct(structName, structLoc);
72 insertConstrainFn(structName, constrainLoc);
73 return *this;
74 }
75
76 ModuleBuilder &insertConstrainOnlyStruct(std::string_view structName) {
77 auto unk = getUnknownLoc();
78 return insertConstrainOnlyStruct(structName, unk, unk);
79 }
80
82 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
83 mlir::Location constrainLoc, int numStructParams = -1
84 ) {
85 insertEmptyStruct(structName, structLoc, numStructParams);
86 insertComputeFn(structName, computeLoc);
87 insertConstrainFn(structName, constrainLoc);
88 return *this;
89 }
90
92 ModuleBuilder &insertFullStruct(std::string_view structName, int numStructParams = -1) {
93 auto unk = getUnknownLoc();
94 return insertFullStruct(structName, unk, unk, unk, numStructParams);
95 }
96
98 std::string_view structName, mlir::Location structLoc, mlir::Location productLoc
99 ) {
100 insertEmptyStruct(structName, structLoc);
101 insertProductFn(structName, productLoc);
102 return *this;
103 }
104
105 ModuleBuilder &insertProductStruct(std::string_view structName) {
106 auto unk = getUnknownLoc();
107 return insertProductStruct(structName, unk, unk);
108 }
109
114 static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc);
116 inline ModuleBuilder &insertComputeFn(std::string_view structName, mlir::Location loc) {
117 return insertComputeFn(*getStruct(structName), loc);
118 }
119 inline ModuleBuilder &insertComputeFn(std::string_view structName) {
120 return insertComputeFn(structName, getUnknownLoc());
121 }
122
126 static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc);
128 inline ModuleBuilder &insertConstrainFn(std::string_view structName, mlir::Location loc) {
129 return insertConstrainFn(*getStruct(structName), getUnknownLoc());
130 }
131 inline ModuleBuilder &insertConstrainFn(std::string_view structName) {
132 return insertConstrainFn(structName, getUnknownLoc());
133 }
134
139 static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc);
141 inline ModuleBuilder &insertProductFn(std::string_view structName, mlir::Location loc) {
142 return insertProductFn(*getStruct(structName), loc);
143 }
144 inline ModuleBuilder &insertProductFn(std::string_view structName) {
145 return insertProductFn(structName, getUnknownLoc());
146 }
147
154 component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc
155 );
157 insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc) {
158 return insertComputeCall(*getStruct(caller), *getStruct(callee), callLoc);
159 }
160 ModuleBuilder &insertComputeCall(std::string_view caller, std::string_view callee) {
161 return insertComputeCall(caller, callee, getUnknownLoc());
162 }
163
171 component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc,
172 mlir::Location fieldDefLoc
173 );
175 std::string_view caller, std::string_view callee, mlir::Location callLoc,
176 mlir::Location fieldDefLoc
177 ) {
178 return insertConstrainCall(*getStruct(caller), *getStruct(callee), callLoc, fieldDefLoc);
179 }
180 ModuleBuilder &insertConstrainCall(std::string_view caller, std::string_view callee) {
181 return insertConstrainCall(caller, callee, getUnknownLoc(), getUnknownLoc());
182 }
183
185 insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
186 inline ModuleBuilder &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type) {
187 return insertFreeFunc(funcName, type, getUnknownLoc());
188 }
189
191 insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc);
192 ModuleBuilder &insertFreeCall(function::FuncDefOp caller, std::string_view callee) {
193 return insertFreeCall(caller, callee, getUnknownLoc());
194 }
195
196 /* Getter methods */
197
199 mlir::ModuleOp &getRootModule() { return rootModule; }
200
201 mlir::FailureOr<component::StructDefOp> getStruct(std::string_view structName) const {
202 if (structMap.find(structName) != structMap.end()) {
203 return structMap.at(structName);
204 }
205 return mlir::failure();
206 }
207
208 mlir::FailureOr<function::FuncDefOp> getComputeFn(std::string_view structName) const {
209 if (computeFnMap.find(structName) != computeFnMap.end()) {
210 return computeFnMap.at(structName);
211 }
212 return mlir::failure();
213 }
214 inline mlir::FailureOr<function::FuncDefOp> getComputeFn(component::StructDefOp op) const {
215 return getComputeFn(op.getName());
216 }
217
218 mlir::FailureOr<function::FuncDefOp> getConstrainFn(std::string_view structName) const {
219 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
220 return constrainFnMap.at(structName);
221 }
222 return mlir::failure();
223 }
224 inline mlir::FailureOr<function::FuncDefOp> getConstrainFn(component::StructDefOp op) const {
225 return getConstrainFn(op.getName());
226 }
227
228 mlir::FailureOr<function::FuncDefOp> getProductFn(std::string_view structName) const {
229 if (productFnMap.find(structName) != productFnMap.end()) {
230 return productFnMap.at(structName);
231 }
232 return mlir::failure();
233 }
234 inline mlir::FailureOr<function::FuncDefOp> getProductFn(component::StructDefOp op) const {
235 return getProductFn(op.getName());
236 }
237
238 mlir::FailureOr<function::FuncDefOp> getFreeFunc(std::string_view funcName) const {
239 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
240 return freeFuncMap.at(funcName);
241 }
242 return mlir::failure();
243 }
244
245 inline mlir::FailureOr<function::FuncDefOp>
246 getFunc(function::FunctionKind kind, std::string_view name) const {
247 switch (kind) {
249 return getComputeFn(name);
251 return getConstrainFn(name);
253 return getProductFn(name);
255 return getFreeFunc(name);
256 }
257 return mlir::failure();
258 }
259
260 /* Helper functions */
261
266 return isReachable(computeNodes, caller, callee);
267 }
268 bool computeReachable(std::string_view caller, std::string_view callee) {
269 return computeReachable(*getStruct(caller), *getStruct(callee));
270 }
271
276 return isReachable(constrainNodes, caller, callee);
277 }
278 bool constrainReachable(std::string_view caller, std::string_view callee) {
279 return constrainReachable(*getStruct(caller), *getStruct(callee));
280 }
281
282private:
283 mlir::MLIRContext *context;
284 mlir::ModuleOp rootModule;
285
286 struct CallNode {
287 mlir::DenseMap<component::StructDefOp, CallNode *> callees;
288 };
289
290 using Def2NodeMap = mlir::DenseMap<component::StructDefOp, CallNode>;
291 using StructDefSet = mlir::DenseSet<component::StructDefOp>;
292
293 Def2NodeMap computeNodes, constrainNodes;
294
295 std::unordered_map<std::string_view, function::FuncDefOp> freeFuncMap;
296 std::unordered_map<std::string_view, component::StructDefOp> structMap;
297 std::unordered_map<std::string_view, function::FuncDefOp> computeFnMap;
298 std::unordered_map<std::string_view, function::FuncDefOp> constrainFnMap;
299 std::unordered_map<std::string_view, function::FuncDefOp> productFnMap;
300
304 void ensureNoSuchFreeFunc(std::string_view funcName);
305
309 void ensureFreeFnExists(std::string_view funcName);
310
314 void ensureNoSuchStruct(std::string_view structName);
315
319 void ensureNoSuchComputeFn(std::string_view structName);
320
324 void ensureComputeFnExists(std::string_view structName);
325
329 void ensureNoSuchConstrainFn(std::string_view structName);
330
334 void ensureConstrainFnExists(std::string_view structName);
335
339 void ensureNoSuchProductFn(std::string_view structName);
340
344 void ensureProductFnExists(std::string_view structName);
345
346 void updateComputeReachability(component::StructDefOp caller, component::StructDefOp callee) {
347 updateReachability(computeNodes, caller, callee);
348 }
349
350 void updateConstrainReachability(component::StructDefOp caller, component::StructDefOp callee) {
351 updateReachability(constrainNodes, caller, callee);
352 }
353
354 void
355 updateReachability(Def2NodeMap &m, component::StructDefOp caller, component::StructDefOp callee) {
356 auto &callerNode = m[caller];
357 auto &calleeNode = m[callee];
358 callerNode.callees[callee] = &calleeNode;
359 }
360
361 bool isReachable(Def2NodeMap &m, component::StructDefOp caller, component::StructDefOp callee) {
362 StructDefSet visited;
363 std::deque<component::StructDefOp> frontier;
364 frontier.push_back(caller);
365
366 while (!frontier.empty()) {
367 auto s = frontier.front();
368 frontier.pop_front();
369 if (!visited.insert(s).second) {
370 continue;
371 }
372
373 if (s == callee) {
374 return true;
375 }
376 for (auto &[calleeStruct, _] : m[s].callees) {
377 frontier.push_back(calleeStruct);
378 }
379 }
380 return false;
381 }
382};
383
384} // namespace llzk
Builds out an LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:41
ModuleBuilder & insertEmptyStruct(std::string_view structName, int numStructParams=-1)
Definition Builders.h:51
mlir::FailureOr< function::FuncDefOp > getProductFn(std::string_view structName) const
Definition Builders.h:228
ModuleBuilder(mlir::ModuleOp m)
Definition Builders.h:43
ModuleBuilder & insertProductStruct(std::string_view structName, mlir::Location structLoc, mlir::Location productLoc)
Definition Builders.h:97
bool constrainReachable(std::string_view caller, std::string_view callee)
Definition Builders.h:278
bool computeReachable(std::string_view caller, std::string_view callee)
Definition Builders.h:268
ModuleBuilder & insertConstrainFn(std::string_view structName, mlir::Location loc)
Definition Builders.h:128
ModuleBuilder & insertProductFn(std::string_view structName)
Definition Builders.h:144
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee)
Definition Builders.h:180
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee)
Definition Builders.h:160
mlir::FailureOr< function::FuncDefOp > getConstrainFn(component::StructDefOp op) const
Definition Builders.h:224
mlir::FailureOr< function::FuncDefOp > getFunc(function::FunctionKind kind, std::string_view name) const
Definition Builders.h:246
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName)
Definition Builders.h:63
ModuleBuilder & insertProductFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertConstrainFn(std::string_view structName)
Definition Builders.h:131
mlir::FailureOr< function::FuncDefOp > getProductFn(component::StructDefOp op) const
Definition Builders.h:234
ModuleBuilder & insertComputeFn(std::string_view structName)
Definition Builders.h:119
bool constrainReachable(component::StructDefOp caller, component::StructDefOp callee)
Returns whether the callee's constrain function is reachable from the caller's by construction.
Definition Builders.h:275
mlir::ModuleOp & getRootModule()
Get the top-level LLZK module.
Definition Builders.h:199
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName) const
Definition Builders.h:218
ModuleBuilder & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
ModuleBuilder & insertProductFn(std::string_view structName, mlir::Location loc)
Definition Builders.h:141
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
Definition Builders.h:174
ModuleBuilder & insertComputeFn(std::string_view structName, mlir::Location loc)
Definition Builders.h:116
mlir::FailureOr< function::FuncDefOp > getComputeFn(component::StructDefOp op) const
Definition Builders.h:214
ModuleBuilder & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type)
Definition Builders.h:186
bool computeReachable(component::StructDefOp caller, component::StructDefOp callee)
Returns whether the callee's compute function is reachable from the caller's by construction.
Definition Builders.h:265
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
Definition Builders.h:208
mlir::FailureOr< function::FuncDefOp > getFreeFunc(std::string_view funcName) const
Definition Builders.h:238
mlir::Location getUnknownLoc()
Definition Builders.h:47
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
Definition Builders.h:157
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
Definition Builders.cpp:161
ModuleBuilder & insertFreeCall(function::FuncDefOp caller, std::string_view callee)
Definition Builders.h:192
ModuleBuilder & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc, int numStructParams=-1)
Definition Builders.h:81
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
Definition Builders.h:68
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
Definition Builders.cpp:143
ModuleBuilder & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
ModuleBuilder & insertFullStruct(std::string_view structName, int numStructParams=-1)
Inserts a struct with both compute and constrain functions.
Definition Builders.h:92
ModuleBuilder & insertComputeFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
Definition Builders.h:55
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName)
Definition Builders.h:76
ModuleBuilder & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
To call a constraint function, you must:
ModuleBuilder & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Definition Builders.cpp:125
ModuleBuilder & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
mlir::FailureOr< component::StructDefOp > getStruct(std::string_view structName) const
Definition Builders.h:201
ModuleBuilder & insertProductStruct(std::string_view structName)
Definition Builders.h:105
FunctionKind
Kinds of functions in LLZK.
Definition Ops.h:32
@ StructConstrain
Function within a struct named FUNC_NAME_CONSTRAIN.
Definition Ops.h:36
@ StructProduct
Function within a struct named FUNC_NAME_PRODUCT.
Definition Ops.h:38
@ StructCompute
Function within a struct named FUNC_NAME_COMPUTE.
Definition Ops.h:34
@ Free
Function that is not within a struct.
Definition Ops.h:40
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
Definition Builders.cpp:28
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)
Definition Builders.cpp:22
mlir::Location getUnknownLoc(mlir::MLIRContext *context)
Definition Builders.h:25