LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Builders.cpp
Go to the documentation of this file.
1//===-- Builders.cpp - Operation builder implementations --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13
14#include <llvm/Support/ErrorHandling.h>
15
16namespace llzk {
17
18using namespace mlir;
19using namespace component;
20using namespace function;
21
22OwningOpRef<ModuleOp> createLLZKModule(MLIRContext *context, Location loc) {
23 auto mod = ModuleOp::create(loc);
25 return mod;
26}
27
28void addLangAttrForLLZKDialect(mlir::ModuleOp mod) {
29 MLIRContext *ctx = mod.getContext();
30 if (auto dialect = ctx->getOrLoadDialect<LLZKDialect>()) {
31 mod->setAttr(LANG_ATTR_NAME, StringAttr::get(ctx, dialect->getNamespace()));
32 } else {
33 llvm::report_fatal_error("Could not load LLZK dialect!");
34 }
35}
36
37/* ModuleBuilder */
38
39void ModuleBuilder::ensureNoSuchFreeFunc(std::string_view funcName) {
40 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
41 auto error_message = "global function " + Twine(funcName) + " already exists!";
42 llvm::report_fatal_error(error_message);
43 }
44}
45
46void ModuleBuilder::ensureFreeFnExists(std::string_view funcName) {
47 if (freeFuncMap.find(funcName) == freeFuncMap.end()) {
48 auto error_message = "global function " + Twine(funcName) + " does not exist!";
49 llvm::report_fatal_error(error_message);
50 }
51}
52
53void ModuleBuilder::ensureNoSuchStruct(std::string_view structName) {
54 if (structMap.find(structName) != structMap.end()) {
55 auto error_message = "struct " + Twine(structName) + " already exists!";
56 llvm::report_fatal_error(error_message);
57 }
58}
59
60void ModuleBuilder::ensureNoSuchComputeFn(std::string_view structName) {
61 if (computeFnMap.find(structName) != computeFnMap.end()) {
62 auto error_message = "struct " + Twine(structName) + " already has a compute function!";
63 llvm::report_fatal_error(error_message);
64 }
65}
66
67void ModuleBuilder::ensureComputeFnExists(std::string_view structName) {
68 if (computeFnMap.find(structName) == computeFnMap.end()) {
69 auto error_message = "struct " + Twine(structName) + " has no compute function!";
70 llvm::report_fatal_error(error_message);
71 }
72}
73
74void ModuleBuilder::ensureNoSuchConstrainFn(std::string_view structName) {
75 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
76 auto error_message = "struct " + Twine(structName) + " already has a constrain function!";
77 llvm::report_fatal_error(error_message);
78 }
79}
80
81void ModuleBuilder::ensureConstrainFnExists(std::string_view structName) {
82 if (constrainFnMap.find(structName) == constrainFnMap.end()) {
83 auto error_message = "struct " + Twine(structName) + " has no constrain function!";
84 llvm::report_fatal_error(error_message);
85 }
86}
87
88void ModuleBuilder::ensureNoSuchProductFn(std::string_view structName) {
89 if (productFnMap.find(structName) != productFnMap.end()) {
90 auto error_message = "struct " + Twine(structName) + " already has a product function!";
91 llvm::report_fatal_error(error_message);
92 }
93}
94
95void ModuleBuilder::ensureProductFnExists(std::string_view structName) {
96 if (productFnMap.find(structName) == productFnMap.end()) {
97 auto error_message = "struct " + Twine(structName) + " has no product function!";
98 llvm::report_fatal_error(error_message);
99 }
100}
101
103ModuleBuilder::insertEmptyStruct(std::string_view structName, Location loc, int numStructParams) {
104 ensureNoSuchStruct(structName);
105
106 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
107 auto structNameAttr = StringAttr::get(context, structName);
108 ArrayAttr structParams = nullptr;
109 if (numStructParams >= 0) {
110 SmallVector<Attribute> paramNames;
111 for (int i = 0; i < numStructParams; ++i) {
112 paramNames.push_back(FlatSymbolRefAttr::get(context, "T" + std::to_string(i)));
113 }
114 structParams = opBuilder.getArrayAttr(paramNames);
115 }
116 auto structDef = opBuilder.create<StructDefOp>(loc, structNameAttr, structParams);
117 // populate the initial region
118 auto &region = structDef.getRegion();
119 (void)region.emplaceBlock();
120 structMap[structName] = structDef;
121
122 return *this;
123}
124
126 MLIRContext *context = op.getContext();
127 OpBuilder opBuilder(op.getBodyRegion());
128 auto fnOp = opBuilder.create<FuncDefOp>(
129 loc, StringAttr::get(context, FUNC_NAME_COMPUTE),
130 FunctionType::get(context, {}, {op.getType()})
131 );
132 fnOp.setAllowWitnessAttr();
133 fnOp.addEntryBlock();
134 return fnOp;
135}
136
138 ensureNoSuchComputeFn(op.getName());
139 computeFnMap[op.getName()] = buildComputeFn(op, loc);
140 return *this;
141}
142
144 MLIRContext *context = op.getContext();
145 OpBuilder opBuilder(op.getBodyRegion());
146 auto fnOp = opBuilder.create<FuncDefOp>(
147 loc, StringAttr::get(context, FUNC_NAME_CONSTRAIN),
148 FunctionType::get(context, {op.getType()}, {})
149 );
150 fnOp.setAllowConstraintAttr();
151 fnOp.addEntryBlock();
152 return fnOp;
153}
154
156 ensureNoSuchConstrainFn(op.getName());
157 constrainFnMap[op.getName()] = buildConstrainFn(op, loc);
158 return *this;
159}
160
162 MLIRContext *context = op.getContext();
163 OpBuilder opBuilder(op.getBodyRegion());
164 auto fnOp = opBuilder.create<FuncDefOp>(
165 loc, StringAttr::get(context, FUNC_NAME_PRODUCT),
166 FunctionType::get(context, {}, {op.getType()})
167 );
168 fnOp.setAllowWitnessAttr();
169 fnOp.setAllowConstraintAttr();
170 fnOp.addEntryBlock();
171 return fnOp;
172}
173
175 ensureNoSuchProductFn(op.getName());
176 productFnMap[op.getName()] = buildProductFn(op, loc);
177 return *this;
178}
179
180ModuleBuilder &
181ModuleBuilder::insertComputeCall(StructDefOp caller, StructDefOp callee, Location callLoc) {
182 ensureComputeFnExists(caller.getName());
183 ensureComputeFnExists(callee.getName());
184
185 auto callerFn = computeFnMap.at(caller.getName());
186 auto calleeFn = computeFnMap.at(callee.getName());
187
188 OpBuilder builder(callerFn.getBody());
189 builder.create<CallOp>(callLoc, calleeFn);
190 updateComputeReachability(caller, callee);
191 return *this;
192}
193
195 StructDefOp caller, StructDefOp callee, Location callLoc, Location fieldDefLoc
196) {
197 ensureConstrainFnExists(caller.getName());
198 ensureConstrainFnExists(callee.getName());
199
200 FuncDefOp callerFn = constrainFnMap.at(caller.getName());
201 FuncDefOp calleeFn = constrainFnMap.at(callee.getName());
202 StructType calleeTy = callee.getType();
203
204 size_t numOps = caller.getBody()->getOperations().size();
205 auto fieldName = StringAttr::get(context, callee.getName().str() + std::to_string(numOps));
206
207 // Insert the field declaration op
208 {
209 OpBuilder builder(caller.getBodyRegion());
210 builder.create<FieldDefOp>(fieldDefLoc, fieldName, calleeTy);
211 }
212
213 // Insert the constrain function ops
214 {
215 OpBuilder builder(callerFn.getBody());
216
217 auto field = builder.create<FieldReadOp>(
218 callLoc, calleeTy, callerFn.getSelfValueFromConstrain(), fieldName
219 );
220 builder.create<CallOp>(
221 callLoc, TypeRange {}, calleeFn.getFullyQualifiedName(), ValueRange {field}
222 );
223 }
224 updateConstrainReachability(caller, callee);
225 return *this;
226}
227
229ModuleBuilder::insertFreeFunc(std::string_view funcName, FunctionType type, Location loc) {
230 ensureNoSuchFreeFunc(funcName);
231
232 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
233 auto funcDef = opBuilder.create<FuncDefOp>(loc, funcName, type);
234 (void)funcDef.addEntryBlock();
235 freeFuncMap[funcName] = funcDef;
236
237 return *this;
238}
239
241ModuleBuilder::insertFreeCall(FuncDefOp caller, std::string_view callee, Location callLoc) {
242 ensureFreeFnExists(callee);
243 FuncDefOp calleeFn = freeFuncMap.at(callee);
244
245 OpBuilder builder(caller.getBody());
246 builder.create<CallOp>(callLoc, calleeFn);
247 return *this;
248}
249
250} // namespace llzk
Builds out a LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:41
ModuleBuilder & insertProductFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
Definition Builders.cpp:161
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
Definition Builders.cpp:143
ModuleBuilder & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
ModuleBuilder & insertComputeFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
To call a constraint function, you must:
ModuleBuilder & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Definition Builders.cpp:125
ModuleBuilder & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:1176
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the IR language name.
Definition Constants.h:32
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:29
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
Definition Builders.cpp:28
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)
Definition Builders.cpp:22