LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Builders.cpp
Go to the documentation of this file.
1//===-- Builders.cpp - Operation builder implementations --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13
14#include <llvm/Support/ErrorHandling.h>
15
16namespace llzk {
17
18using namespace mlir;
19using namespace component;
20using namespace function;
21
22OwningOpRef<ModuleOp> createLLZKModule(MLIRContext *context, Location loc) {
23 auto mod = ModuleOp::create(loc);
25 return mod;
26}
27
28void addLangAttrForLLZKDialect(mlir::ModuleOp mod) {
29 MLIRContext *ctx = mod.getContext();
30 if (auto dialect = ctx->getOrLoadDialect<LLZKDialect>()) {
31 mod->setAttr(LANG_ATTR_NAME, StringAttr::get(ctx, dialect->getNamespace()));
32 } else {
33 llvm::report_fatal_error("Could not load LLZK dialect!");
34 }
35}
36
37/* ModuleBuilder */
38
39void ModuleBuilder::ensureNoSuchGlobalFunc(std::string_view funcName) {
40 if (globalFuncMap.find(funcName) != globalFuncMap.end()) {
41 auto error_message = "global function " + Twine(funcName) + " already exists!";
42 llvm::report_fatal_error(error_message);
43 }
44}
45
46void ModuleBuilder::ensureGlobalFnExists(std::string_view funcName) {
47 if (globalFuncMap.find(funcName) == globalFuncMap.end()) {
48 auto error_message = "global function " + Twine(funcName) + " does not exist!";
49 llvm::report_fatal_error(error_message);
50 }
51}
52
53void ModuleBuilder::ensureNoSuchStruct(std::string_view structName) {
54 if (structMap.find(structName) != structMap.end()) {
55 auto error_message = "struct " + Twine(structName) + " already exists!";
56 llvm::report_fatal_error(error_message);
57 }
58}
59
60void ModuleBuilder::ensureNoSuchComputeFn(std::string_view structName) {
61 if (computeFnMap.find(structName) != computeFnMap.end()) {
62 auto error_message = "struct " + Twine(structName) + " already has a compute function!";
63 llvm::report_fatal_error(error_message);
64 }
65}
66
67void ModuleBuilder::ensureComputeFnExists(std::string_view structName) {
68 if (computeFnMap.find(structName) == computeFnMap.end()) {
69 auto error_message = "struct " + Twine(structName) + " has no compute function!";
70 llvm::report_fatal_error(error_message);
71 }
72}
73
74void ModuleBuilder::ensureNoSuchConstrainFn(std::string_view structName) {
75 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
76 auto error_message = "struct " + Twine(structName) + " already has a constrain function!";
77 llvm::report_fatal_error(error_message);
78 }
79}
80
81void ModuleBuilder::ensureConstrainFnExists(std::string_view structName) {
82 if (constrainFnMap.find(structName) == constrainFnMap.end()) {
83 auto error_message = "struct " + Twine(structName) + " has no constrain function!";
84 llvm::report_fatal_error(error_message);
85 }
86}
87
89ModuleBuilder::insertEmptyStruct(std::string_view structName, Location loc, int numStructParams) {
90 ensureNoSuchStruct(structName);
91
92 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
93 auto structNameAttr = StringAttr::get(context, structName);
94 ArrayAttr structParams = nullptr;
95 if (numStructParams >= 0) {
96 SmallVector<Attribute> paramNames;
97 for (int i = 0; i < numStructParams; ++i) {
98 paramNames.push_back(FlatSymbolRefAttr::get(context, "T" + std::to_string(i)));
99 }
100 structParams = opBuilder.getArrayAttr(paramNames);
101 }
102 auto structDef = opBuilder.create<StructDefOp>(loc, structNameAttr, structParams);
103 // populate the initial region
104 auto &region = structDef.getRegion();
105 (void)region.emplaceBlock();
106 structMap[structName] = structDef;
107
108 return *this;
109}
110
112 ensureNoSuchComputeFn(op.getName());
113
114 OpBuilder opBuilder(op.getBody());
115
116 auto fnOp = opBuilder.create<FuncDefOp>(
117 loc, StringAttr::get(context, FUNC_NAME_COMPUTE),
118 FunctionType::get(context, {}, {op.getType()})
119 );
120 fnOp.addEntryBlock();
121 computeFnMap[op.getName()] = fnOp;
122 return *this;
123}
124
126 ensureNoSuchConstrainFn(op.getName());
127
128 OpBuilder opBuilder(op.getBody());
129
130 auto fnOp = opBuilder.create<FuncDefOp>(
131 loc, StringAttr::get(context, FUNC_NAME_CONSTRAIN),
132 FunctionType::get(context, {op.getType()}, {})
133 );
134 fnOp.addEntryBlock();
135 constrainFnMap[op.getName()] = fnOp;
136 return *this;
137}
138
140ModuleBuilder::insertComputeCall(StructDefOp caller, StructDefOp callee, Location callLoc) {
141 ensureComputeFnExists(caller.getName());
142 ensureComputeFnExists(callee.getName());
143
144 auto callerFn = computeFnMap.at(caller.getName());
145 auto calleeFn = computeFnMap.at(callee.getName());
146
147 OpBuilder builder(callerFn.getBody());
148 builder.create<CallOp>(callLoc, calleeFn.getResultTypes(), calleeFn.getFullyQualifiedName());
149 updateComputeReachability(caller, callee);
150 return *this;
151}
152
154 StructDefOp caller, StructDefOp callee, Location callLoc, Location fieldDefLoc
155) {
156 ensureConstrainFnExists(caller.getName());
157 ensureConstrainFnExists(callee.getName());
158
159 auto callerFn = constrainFnMap.at(caller.getName());
160 auto calleeFn = constrainFnMap.at(callee.getName());
161 auto calleeTy = callee.getType();
162
163 size_t numOps = 0;
164 for (auto it = caller.getBody().begin(); it != caller.getBody().end(); it++, numOps++)
165 ;
166 auto fieldName = StringAttr::get(context, callee.getName().str() + std::to_string(numOps));
167
168 // Insert the field declaration op
169 {
170 OpBuilder builder(caller.getBody());
171 builder.create<FieldDefOp>(fieldDefLoc, fieldName, calleeTy);
172 }
173
174 // Insert the constrain function ops
175 {
176 OpBuilder builder(callerFn.getBody());
177
178 auto field = builder.create<FieldReadOp>(
179 callLoc, calleeTy,
180 callerFn.getBody().getArgument(0), // first arg is self
181 fieldName
182 );
183 builder.create<CallOp>(
184 callLoc, TypeRange {}, calleeFn.getFullyQualifiedName(), ValueRange {field}
185 );
186 }
187 updateConstrainReachability(caller, callee);
188 return *this;
189}
190
192ModuleBuilder::insertGlobalFunc(std::string_view funcName, FunctionType type, Location loc) {
193 ensureNoSuchGlobalFunc(funcName);
194
195 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
196 auto funcDef = opBuilder.create<FuncDefOp>(loc, funcName, type);
197 (void)funcDef.addEntryBlock();
198 globalFuncMap[funcName] = funcDef;
199
200 return *this;
201}
202
204ModuleBuilder::insertGlobalCall(FuncDefOp caller, std::string_view callee, Location callLoc) {
205 ensureGlobalFnExists(callee);
206 FuncDefOp calleeFn = globalFuncMap.at(callee);
207
208 OpBuilder builder(caller.getBody());
209 builder.create<CallOp>(callLoc, calleeFn.getResultTypes(), calleeFn.getFullyQualifiedName());
210 return *this;
211}
212
213} // namespace llzk
Builds out a LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:41
ModuleBuilder & insertConstrainFn(llzk::component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
ModuleBuilder & insertGlobalCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
ModuleBuilder & insertComputeFn(llzk::component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
ModuleBuilder & insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
ModuleBuilder & insertComputeCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
To call a constraint function, you must:
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the IR language name.
Definition Constants.h:31
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
Definition Builders.cpp:28
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)
Definition Builders.cpp:22