14#include <llvm/Support/ErrorHandling.h>
23 auto mod = ModuleOp::create(loc);
29 MLIRContext *ctx =
mod.getContext();
30 if (
auto dialect = ctx->getOrLoadDialect<
LLZKDialect>()) {
33 llvm::report_fatal_error(
"Could not load LLZK dialect!");
39void ModuleBuilder::ensureNoSuchGlobalFunc(std::string_view funcName) {
40 if (globalFuncMap.find(funcName) != globalFuncMap.end()) {
41 auto error_message =
"global function " + Twine(funcName) +
" already exists!";
42 llvm::report_fatal_error(error_message);
46void ModuleBuilder::ensureGlobalFnExists(std::string_view funcName) {
47 if (globalFuncMap.find(funcName) == globalFuncMap.end()) {
48 auto error_message =
"global function " + Twine(funcName) +
" does not exist!";
49 llvm::report_fatal_error(error_message);
53void ModuleBuilder::ensureNoSuchStruct(std::string_view structName) {
54 if (structMap.find(structName) != structMap.end()) {
55 auto error_message =
"struct " + Twine(structName) +
" already exists!";
56 llvm::report_fatal_error(error_message);
60void ModuleBuilder::ensureNoSuchComputeFn(std::string_view structName) {
61 if (computeFnMap.find(structName) != computeFnMap.end()) {
62 auto error_message =
"struct " + Twine(structName) +
" already has a compute function!";
63 llvm::report_fatal_error(error_message);
67void ModuleBuilder::ensureComputeFnExists(std::string_view structName) {
68 if (computeFnMap.find(structName) == computeFnMap.end()) {
69 auto error_message =
"struct " + Twine(structName) +
" has no compute function!";
70 llvm::report_fatal_error(error_message);
74void ModuleBuilder::ensureNoSuchConstrainFn(std::string_view structName) {
75 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
76 auto error_message =
"struct " + Twine(structName) +
" already has a constrain function!";
77 llvm::report_fatal_error(error_message);
81void ModuleBuilder::ensureConstrainFnExists(std::string_view structName) {
82 if (constrainFnMap.find(structName) == constrainFnMap.end()) {
83 auto error_message =
"struct " + Twine(structName) +
" has no constrain function!";
84 llvm::report_fatal_error(error_message);
90 ensureNoSuchStruct(structName);
92 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
93 auto structNameAttr = StringAttr::get(context, structName);
94 ArrayAttr structParams =
nullptr;
95 if (numStructParams >= 0) {
96 SmallVector<Attribute> paramNames;
97 for (
int i = 0; i < numStructParams; ++i) {
98 paramNames.push_back(FlatSymbolRefAttr::get(context,
"T" + std::to_string(i)));
100 structParams = opBuilder.getArrayAttr(paramNames);
102 auto structDef = opBuilder.create<StructDefOp>(loc, structNameAttr, structParams);
104 auto ®ion = structDef.getRegion();
105 (void)region.emplaceBlock();
106 structMap[structName] = structDef;
112 ensureNoSuchComputeFn(op.getName());
114 OpBuilder opBuilder(op.getBody());
116 auto fnOp = opBuilder.create<FuncDefOp>(
118 FunctionType::get(context, {}, {op.getType()})
120 fnOp.addEntryBlock();
121 computeFnMap[op.getName()] = fnOp;
126 ensureNoSuchConstrainFn(op.getName());
128 OpBuilder opBuilder(op.getBody());
130 auto fnOp = opBuilder.create<FuncDefOp>(
132 FunctionType::get(context, {op.getType()}, {})
134 fnOp.addEntryBlock();
135 constrainFnMap[op.getName()] = fnOp;
141 ensureComputeFnExists(caller.getName());
142 ensureComputeFnExists(callee.getName());
144 auto callerFn = computeFnMap.at(caller.getName());
145 auto calleeFn = computeFnMap.at(callee.getName());
147 OpBuilder builder(callerFn.getBody());
148 builder.create<CallOp>(callLoc, calleeFn.getResultTypes(), calleeFn.getFullyQualifiedName());
149 updateComputeReachability(caller, callee);
156 ensureConstrainFnExists(caller.getName());
157 ensureConstrainFnExists(callee.getName());
159 auto callerFn = constrainFnMap.at(caller.getName());
160 auto calleeFn = constrainFnMap.at(callee.getName());
161 auto calleeTy = callee.getType();
164 for (
auto it = caller.getBody().begin(); it != caller.getBody().end(); it++, numOps++)
166 auto fieldName = StringAttr::get(context, callee.getName().str() + std::to_string(numOps));
170 OpBuilder builder(caller.getBody());
171 builder.create<FieldDefOp>(fieldDefLoc, fieldName, calleeTy);
176 OpBuilder builder(callerFn.getBody());
178 auto field = builder.create<FieldReadOp>(
180 callerFn.getBody().getArgument(0),
183 builder.create<CallOp>(
184 callLoc, TypeRange {}, calleeFn.getFullyQualifiedName(), ValueRange {field}
187 updateConstrainReachability(caller, callee);
193 ensureNoSuchGlobalFunc(funcName);
195 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
196 auto funcDef = opBuilder.create<FuncDefOp>(loc, funcName, type);
197 (void)funcDef.addEntryBlock();
198 globalFuncMap[funcName] = funcDef;
205 ensureGlobalFnExists(callee);
206 FuncDefOp calleeFn = globalFuncMap.at(callee);
208 OpBuilder builder(caller.getBody());
209 builder.create<CallOp>(callLoc, calleeFn.getResultTypes(), calleeFn.getFullyQualifiedName());
Builds an LLZK-compliant module and provides utilities for populating it.
ModuleBuilder & insertConstrainFn(llzk::component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
ModuleBuilder & insertGlobalCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
ModuleBuilder & insertComputeFn(llzk::component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
ModuleBuilder & insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
ModuleBuilder & insertComputeCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc)
For compute, the only requirement is inserting the call itself.
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
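To call a constraint function, you must: (1) add a field of the callee's struct type to the caller struct, (2) read that field from the first (struct-typed) argument of the caller's constrain function, and (3) pass the value read to the callee's constrain function.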
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and, respectively, constraint generation) functions.
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the IR language name.
constexpr char FUNC_NAME_CONSTRAIN[]
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)