14#include <mlir/IR/Builders.h>
15#include <mlir/IR/MLIRContext.h>
17#include <llvm/ADT/DenseMap.h>
18#include <llvm/ADT/DenseSet.h>
21#include <unordered_map>
26 return mlir::UnknownLoc::get(context);
29mlir::OwningOpRef<mlir::ModuleOp>
createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
/// Wraps an existing module: stores `m` as the root module and caches the context
/// returned by `m.getContext()`.
43 ModuleBuilder(mlir::ModuleOp m) : context(m.getContext()), rootModule(m) {}
50 insertEmptyStruct(std::string_view structName, mlir::Location loc,
int numStructParams = -1);
56 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
69 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
82 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
83 mlir::Location constrainLoc,
int numStructParams = -1
98 std::string_view structName, mlir::Location structLoc, mlir::Location productLoc
172 mlir::Location fieldDefLoc
175 std::string_view caller, std::string_view callee, mlir::Location callLoc,
176 mlir::Location fieldDefLoc
185 insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
201 mlir::FailureOr<component::StructDefOp>
getStruct(std::string_view structName)
const {
202 if (structMap.find(structName) != structMap.end()) {
203 return structMap.at(structName);
205 return mlir::failure();
208 mlir::FailureOr<function::FuncDefOp>
getComputeFn(std::string_view structName)
const {
209 if (computeFnMap.find(structName) != computeFnMap.end()) {
210 return computeFnMap.at(structName);
212 return mlir::failure();
218 mlir::FailureOr<function::FuncDefOp>
getConstrainFn(std::string_view structName)
const {
219 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
220 return constrainFnMap.at(structName);
222 return mlir::failure();
228 mlir::FailureOr<function::FuncDefOp>
getProductFn(std::string_view structName)
const {
229 if (productFnMap.find(structName) != productFnMap.end()) {
230 return productFnMap.at(structName);
232 return mlir::failure();
238 mlir::FailureOr<function::FuncDefOp>
getFreeFunc(std::string_view funcName)
const {
239 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
240 return freeFuncMap.at(funcName);
242 return mlir::failure();
245 inline mlir::FailureOr<function::FuncDefOp>
257 return mlir::failure();
266 return isReachable(computeNodes, caller, callee);
276 return isReachable(constrainNodes, caller, callee);
283 mlir::MLIRContext *context;
284 mlir::ModuleOp rootModule;
287 mlir::DenseMap<component::StructDefOp, CallNode *> callees;
290 using Def2NodeMap = mlir::DenseMap<component::StructDefOp, CallNode>;
291 using StructDefSet = mlir::DenseSet<component::StructDefOp>;
293 Def2NodeMap computeNodes, constrainNodes;
// Name-keyed lookup tables read by the get*() accessors (getFreeFunc, getStruct,
// getComputeFn, getConstrainFn, getProductFn).
// NOTE(review): the keys are std::string_view, which does not own its characters, so a key
// is only valid while the string it was created from is alive. Confirm that the insertion
// sites key these maps on storage that outlives every later lookup.
295 std::unordered_map<std::string_view, function::FuncDefOp> freeFuncMap;
296 std::unordered_map<std::string_view, component::StructDefOp> structMap;
297 std::unordered_map<std::string_view, function::FuncDefOp> computeFnMap;
298 std::unordered_map<std::string_view, function::FuncDefOp> constrainFnMap;
299 std::unordered_map<std::string_view, function::FuncDefOp> productFnMap;
304 void ensureNoSuchFreeFunc(std::string_view funcName);
309 void ensureFreeFnExists(std::string_view funcName);
314 void ensureNoSuchStruct(std::string_view structName);
319 void ensureNoSuchComputeFn(std::string_view structName);
324 void ensureComputeFnExists(std::string_view structName);
329 void ensureNoSuchConstrainFn(std::string_view structName);
334 void ensureConstrainFnExists(std::string_view structName);
339 void ensureNoSuchProductFn(std::string_view structName);
344 void ensureProductFnExists(std::string_view structName);
/// Records a caller -> callee edge in `computeNodes` by forwarding to updateReachability.
346 void updateComputeReachability(component::StructDefOp caller, component::StructDefOp callee) {
347 updateReachability(computeNodes, caller, callee);
/// Records a caller -> callee edge in `constrainNodes` by forwarding to updateReachability.
350 void updateConstrainReachability(component::StructDefOp caller, component::StructDefOp callee) {
351 updateReachability(constrainNodes, caller, callee);
355 updateReachability(Def2NodeMap &m, component::StructDefOp caller, component::StructDefOp callee) {
356 auto &callerNode = m[caller];
357 auto &calleeNode = m[callee];
358 callerNode.callees[callee] = &calleeNode;
361 bool isReachable(Def2NodeMap &m, component::StructDefOp caller, component::StructDefOp callee) {
362 StructDefSet visited;
363 std::deque<component::StructDefOp> frontier;
364 frontier.push_back(caller);
366 while (!frontier.empty()) {
367 auto s = frontier.front();
368 frontier.pop_front();
369 if (!visited.insert(s).second) {
376 for (
auto &[calleeStruct, _] : m[s].callees) {
377 frontier.push_back(calleeStruct);
Builds out an LLZK-compliant module and provides utilities for populating that module.
ModuleBuilder & insertEmptyStruct(std::string_view structName, int numStructParams=-1)
mlir::FailureOr< function::FuncDefOp > getProductFn(std::string_view structName) const
ModuleBuilder(mlir::ModuleOp m)
ModuleBuilder & insertProductStruct(std::string_view structName, mlir::Location structLoc, mlir::Location productLoc)
bool constrainReachable(std::string_view caller, std::string_view callee)
bool computeReachable(std::string_view caller, std::string_view callee)
ModuleBuilder & insertConstrainFn(std::string_view structName, mlir::Location loc)
ModuleBuilder & insertProductFn(std::string_view structName)
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee)
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee)
mlir::FailureOr< function::FuncDefOp > getConstrainFn(component::StructDefOp op) const
mlir::FailureOr< function::FuncDefOp > getFunc(function::FunctionKind kind, std::string_view name) const
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName)
ModuleBuilder & insertProductFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertConstrainFn(std::string_view structName)
mlir::FailureOr< function::FuncDefOp > getProductFn(component::StructDefOp op) const
ModuleBuilder & insertComputeFn(std::string_view structName)
bool constrainReachable(component::StructDefOp caller, component::StructDefOp callee)
Returns whether the callee constrain function is reachable from the caller by construction.
mlir::ModuleOp & getRootModule()
Get the top-level LLZK module.
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName) const
ModuleBuilder & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
ModuleBuilder & insertProductFn(std::string_view structName, mlir::Location loc)
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
ModuleBuilder & insertComputeFn(std::string_view structName, mlir::Location loc)
mlir::FailureOr< function::FuncDefOp > getComputeFn(component::StructDefOp op) const
ModuleBuilder & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type)
bool computeReachable(component::StructDefOp caller, component::StructDefOp callee)
Returns whether the callee compute function is reachable from the caller by construction.
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
mlir::FailureOr< function::FuncDefOp > getFreeFunc(std::string_view funcName) const
mlir::Location getUnknownLoc()
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
ModuleBuilder & insertFreeCall(function::FuncDefOp caller, std::string_view callee)
ModuleBuilder & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc, int numStructParams=-1)
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
ModuleBuilder & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
ModuleBuilder & insertFullStruct(std::string_view structName, int numStructParams=-1)
Inserts a struct with both compute and constrain functions.
ModuleBuilder & insertComputeFn(component::StructDefOp op, mlir::Location loc)
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName)
ModuleBuilder & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
To call a constraint function, you must:
ModuleBuilder & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
ModuleBuilder & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
mlir::FailureOr< component::StructDefOp > getStruct(std::string_view structName) const
ModuleBuilder & insertProductStruct(std::string_view structName)
FunctionKind
Kinds of functions in LLZK.
@ StructConstrain
Function within a struct named FUNC_NAME_CONSTRAIN.
@ StructProduct
Function within a struct named FUNC_NAME_PRODUCT.
@ StructCompute
Function within a struct named FUNC_NAME_COMPUTE.
@ Free
Function that is not within a struct.
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)
mlir::Location getUnknownLoc(mlir::MLIRContext *context)