14#include <llvm/Support/ErrorHandling.h>
23 auto mod = ModuleOp::create(loc);
29 MLIRContext *ctx =
mod.getContext();
30 if (
auto dialect = ctx->getOrLoadDialect<
LLZKDialect>()) {
33 llvm::report_fatal_error(
"Could not load LLZK dialect!");
39void ModuleBuilder::ensureNoSuchFreeFunc(std::string_view funcName) {
40 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
41 auto error_message =
"global function " + Twine(funcName) +
" already exists!";
42 llvm::report_fatal_error(error_message);
46void ModuleBuilder::ensureFreeFnExists(std::string_view funcName) {
47 if (freeFuncMap.find(funcName) == freeFuncMap.end()) {
48 auto error_message =
"global function " + Twine(funcName) +
" does not exist!";
49 llvm::report_fatal_error(error_message);
53void ModuleBuilder::ensureNoSuchStruct(std::string_view structName) {
54 if (structMap.find(structName) != structMap.end()) {
55 auto error_message =
"struct " + Twine(structName) +
" already exists!";
56 llvm::report_fatal_error(error_message);
60void ModuleBuilder::ensureNoSuchComputeFn(std::string_view structName) {
61 if (computeFnMap.find(structName) != computeFnMap.end()) {
62 auto error_message =
"struct " + Twine(structName) +
" already has a compute function!";
63 llvm::report_fatal_error(error_message);
67void ModuleBuilder::ensureComputeFnExists(std::string_view structName) {
68 if (computeFnMap.find(structName) == computeFnMap.end()) {
69 auto error_message =
"struct " + Twine(structName) +
" has no compute function!";
70 llvm::report_fatal_error(error_message);
74void ModuleBuilder::ensureNoSuchConstrainFn(std::string_view structName) {
75 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
76 auto error_message =
"struct " + Twine(structName) +
" already has a constrain function!";
77 llvm::report_fatal_error(error_message);
81void ModuleBuilder::ensureConstrainFnExists(std::string_view structName) {
82 if (constrainFnMap.find(structName) == constrainFnMap.end()) {
83 auto error_message =
"struct " + Twine(structName) +
" has no constrain function!";
84 llvm::report_fatal_error(error_message);
88void ModuleBuilder::ensureNoSuchProductFn(std::string_view structName) {
89 if (productFnMap.find(structName) != productFnMap.end()) {
90 auto error_message =
"struct " + Twine(structName) +
" already has a product function!";
91 llvm::report_fatal_error(error_message);
95void ModuleBuilder::ensureProductFnExists(std::string_view structName) {
96 if (productFnMap.find(structName) == productFnMap.end()) {
97 auto error_message =
"struct " + Twine(structName) +
" has no product function!";
98 llvm::report_fatal_error(error_message);
104 ensureNoSuchStruct(structName);
106 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
107 auto structNameAttr = StringAttr::get(context, structName);
108 ArrayAttr structParams =
nullptr;
109 if (numStructParams >= 0) {
110 SmallVector<Attribute> paramNames;
111 for (
int i = 0; i < numStructParams; ++i) {
112 paramNames.push_back(FlatSymbolRefAttr::get(context,
"T" + std::to_string(i)));
114 structParams = opBuilder.getArrayAttr(paramNames);
116 auto structDef = opBuilder.create<StructDefOp>(loc, structNameAttr, structParams);
118 auto ®ion = structDef.getRegion();
119 (void)region.emplaceBlock();
120 structMap[structName] = structDef;
126 MLIRContext *context = op.getContext();
130 FunctionType::get(context, {}, {op.
getType()})
132 fnOp.setAllowWitnessAttr();
133 fnOp.addEntryBlock();
138 ensureNoSuchComputeFn(op.getName());
144 MLIRContext *context = op.getContext();
148 FunctionType::get(context, {op.
getType()}, {})
150 fnOp.setAllowConstraintAttr();
151 fnOp.addEntryBlock();
156 ensureNoSuchConstrainFn(op.getName());
162 MLIRContext *context = op.getContext();
166 FunctionType::get(context, {}, {op.
getType()})
168 fnOp.setAllowWitnessAttr();
169 fnOp.setAllowConstraintAttr();
170 fnOp.addEntryBlock();
175 ensureNoSuchProductFn(op.getName());
182 ensureComputeFnExists(caller.getName());
183 ensureComputeFnExists(callee.getName());
185 auto callerFn = computeFnMap.at(caller.getName());
186 auto calleeFn = computeFnMap.at(callee.getName());
188 OpBuilder builder(callerFn.getBody());
189 builder.create<CallOp>(callLoc, calleeFn);
190 updateComputeReachability(caller, callee);
197 ensureConstrainFnExists(caller.getName());
198 ensureConstrainFnExists(callee.getName());
200 FuncDefOp callerFn = constrainFnMap.at(caller.getName());
201 FuncDefOp calleeFn = constrainFnMap.at(callee.getName());
202 StructType calleeTy = callee.getType();
204 size_t numOps = caller.getBody()->getOperations().size();
205 auto fieldName = StringAttr::get(context, callee.getName().str() + std::to_string(numOps));
209 OpBuilder builder(caller.getBodyRegion());
210 builder.create<FieldDefOp>(fieldDefLoc, fieldName, calleeTy);
215 OpBuilder builder(callerFn.getBody());
217 auto field = builder.create<FieldReadOp>(
218 callLoc, calleeTy, callerFn.getSelfValueFromConstrain(), fieldName
220 builder.create<CallOp>(
221 callLoc, TypeRange {}, calleeFn.getFullyQualifiedName(), ValueRange {field}
224 updateConstrainReachability(caller, callee);
230 ensureNoSuchFreeFunc(funcName);
232 OpBuilder opBuilder(rootModule.getBody(), rootModule.getBody()->begin());
233 auto funcDef = opBuilder.create<FuncDefOp>(loc, funcName, type);
234 (void)funcDef.addEntryBlock();
235 freeFuncMap[funcName] = funcDef;
242 ensureFreeFnExists(callee);
243 FuncDefOp calleeFn = freeFuncMap.at(callee);
245 OpBuilder builder(caller.getBody());
246 builder.create<CallOp>(callLoc, calleeFn);
Builds an LLZK-compliant module and provides utilities for populating it.
ModuleBuilder & insertProductFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
ModuleBuilder & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
For compute, the only requirement is the call itself.
ModuleBuilder & insertComputeFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
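To call a constraint function, you must: (1) add a field of the callee's struct type to the caller struct, (2) read that field in the caller's constrain function, and (3) pass the read value to the callee's constrain function.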
ModuleBuilder & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
ModuleBuilder & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the IR language name.
constexpr char FUNC_NAME_CONSTRAIN[]
constexpr char FUNC_NAME_PRODUCT[]
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)