14#include <mlir/IR/Builders.h>
15#include <mlir/IR/MLIRContext.h>
17#include <llvm/ADT/DenseMap.h>
18#include <llvm/ADT/DenseSet.h>
21#include <unordered_map>
26 return mlir::UnknownLoc::get(context);
29mlir::OwningOpRef<mlir::ModuleOp>
createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
43 ModuleBuilder(mlir::ModuleOp m) : context(m.getContext()), rootModule(m) {}
50 insertEmptyStruct(std::string_view structName, mlir::Location loc,
int numStructParams = -1);
56 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
69 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
82 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
83 mlir::Location constrainLoc,
int numStructParams = -1
127 mlir::Location callLoc
145 mlir::Location callLoc, mlir::Location fieldDefLoc
148 std::string_view caller, std::string_view callee, mlir::Location callLoc,
149 mlir::Location fieldDefLoc
158 insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
174 mlir::FailureOr<llzk::component::StructDefOp>
getStruct(std::string_view structName)
const {
175 if (structMap.find(structName) != structMap.end()) {
176 return structMap.at(structName);
178 return mlir::failure();
181 mlir::FailureOr<function::FuncDefOp>
getComputeFn(std::string_view structName)
const {
182 if (computeFnMap.find(structName) != computeFnMap.end()) {
183 return computeFnMap.at(structName);
185 return mlir::failure();
191 mlir::FailureOr<function::FuncDefOp>
getConstrainFn(std::string_view structName) {
192 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
193 return constrainFnMap.at(structName);
195 return mlir::failure();
201 mlir::FailureOr<function::FuncDefOp>
getGlobalFunc(std::string_view funcName)
const {
202 if (globalFuncMap.find(funcName) != globalFuncMap.end()) {
203 return globalFuncMap.at(funcName);
205 return mlir::failure();
214 return isReachable(computeNodes, caller, callee);
225 return isReachable(constrainNodes, caller, callee);
232 mlir::MLIRContext *context;
233 mlir::ModuleOp rootModule;
236 mlir::DenseMap<llzk::component::StructDefOp, CallNode *> callees;
239 using Def2NodeMap = mlir::DenseMap<llzk::component::StructDefOp, CallNode>;
240 using StructDefSet = mlir::DenseSet<llzk::component::StructDefOp>;
242 Def2NodeMap computeNodes, constrainNodes;
244 std::unordered_map<std::string_view, function::FuncDefOp> globalFuncMap;
245 std::unordered_map<std::string_view, llzk::component::StructDefOp> structMap;
246 std::unordered_map<std::string_view, function::FuncDefOp> computeFnMap;
247 std::unordered_map<std::string_view, function::FuncDefOp> constrainFnMap;
252 void ensureNoSuchGlobalFunc(std::string_view funcName);
257 void ensureGlobalFnExists(std::string_view funcName);
262 void ensureNoSuchStruct(std::string_view structName);
267 void ensureNoSuchComputeFn(std::string_view structName);
272 void ensureComputeFnExists(std::string_view structName);
277 void ensureNoSuchConstrainFn(std::string_view structName);
282 void ensureConstrainFnExists(std::string_view structName);
284 void updateComputeReachability(
287 updateReachability(computeNodes, caller, callee);
290 void updateConstrainReachability(
291 llzk::component::StructDefOp caller, llzk::component::StructDefOp callee
293 updateReachability(constrainNodes, caller, callee);
296 void updateReachability(
297 Def2NodeMap &m, llzk::component::StructDefOp caller, llzk::component::StructDefOp callee
299 auto &callerNode = m[caller];
300 auto &calleeNode = m[callee];
301 callerNode.callees[callee] = &calleeNode;
305 Def2NodeMap &m, llzk::component::StructDefOp caller, llzk::component::StructDefOp callee
307 StructDefSet visited;
308 std::deque<llzk::component::StructDefOp> frontier;
309 frontier.push_back(caller);
311 while (!frontier.empty()) {
312 auto s = frontier.front();
313 frontier.pop_front();
314 if (!visited.insert(s).second) {
321 for (
auto &[calleeStruct, _] : m[s].callees) {
322 frontier.push_back(calleeStruct);
Builds an LLZK-compliant module and provides utilities for populating it.
ModuleBuilder & insertEmptyStruct(std::string_view structName, int numStructParams=-1)
ModuleBuilder(mlir::ModuleOp m)
bool constrainReachable(std::string_view caller, std::string_view callee)
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName)
bool computeReachable(std::string_view caller, std::string_view callee)
ModuleBuilder & insertConstrainFn(std::string_view structName, mlir::Location loc)
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee)
ModuleBuilder & insertGlobalCall(function::FuncDefOp caller, std::string_view callee)
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee)
mlir::FailureOr< function::FuncDefOp > getComputeFn(llzk::component::StructDefOp op) const
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName)
ModuleBuilder & insertConstrainFn(std::string_view structName)
ModuleBuilder & insertComputeFn(std::string_view structName)
ModuleBuilder & insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type)
mlir::ModuleOp & getRootModule()
Get the top-level LLZK module.
ModuleBuilder & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
ModuleBuilder & insertComputeFn(std::string_view structName, mlir::Location loc)
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
mlir::Location getUnknownLoc()
ModuleBuilder & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
mlir::FailureOr< function::FuncDefOp > getGlobalFunc(std::string_view funcName) const
ModuleBuilder & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc, int numStructParams=-1)
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
mlir::FailureOr< function::FuncDefOp > getConstrainFn(llzk::component::StructDefOp op)
ModuleBuilder & insertConstrainFn(llzk::component::StructDefOp op, mlir::Location loc)
The constrain function accepts the struct type as its first argument.
ModuleBuilder & insertFullStruct(std::string_view structName, int numStructParams=-1)
Inserts a struct with both compute and constrain functions.
ModuleBuilder & insertGlobalCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
ModuleBuilder & insertComputeFn(llzk::component::StructDefOp op, mlir::Location loc)
The compute function returns the type of the struct that defines it.
bool constrainReachable(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee)
Returns whether the callee's constrain function is reachable from the caller by construction.
ModuleBuilder & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
mlir::FailureOr< llzk::component::StructDefOp > getStruct(std::string_view structName) const
ModuleBuilder & insertGlobalFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
ModuleBuilder & insertComputeCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc)
The only requirement for a compute call is the call itself.
ModuleBuilder & insertEmptyStruct(std::string_view structName, mlir::Location loc, int numStructParams=-1)
ModuleBuilder & insertConstrainOnlyStruct(std::string_view structName)
ModuleBuilder & insertConstrainCall(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee, mlir::Location callLoc, mlir::Location fieldDefLoc)
To call a constraint function, you must:
bool computeReachable(llzk::component::StructDefOp caller, llzk::component::StructDefOp callee)
Returns whether the callee's compute function is reachable from the caller by construction.
void addLangAttrForLLZKDialect(mlir::ModuleOp mod)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *context, Location loc)
mlir::Location getUnknownLoc(mlir::MLIRContext *context)