22#include <mlir/IR/Builders.h>
23#include <mlir/Transforms/InliningUtils.h>
25#include <llvm/Support/Debug.h>
31#define GEN_PASS_DECL_COMPUTECONSTRAINTOPRODUCTPASS
32#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
36#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
42using std::make_unique;
50 if (!computeFunc || !constrainFunc) {
64 std::vector<StructDefOp> alignedStructs;
68 LogicalResult alignCalls(
69 FuncDefOp product, SymbolTableCollection &tables,
82 ModuleOp
mod = getOperation();
85 SymbolTableCollection tables;
87 getAnalysis<LightweightSignalEquivalenceAnalysis>()
109 for (
auto s : alignedStructs) {
110 s.getComputeFuncOp()->erase();
111 s.getConstrainFuncOp()->erase();
116FuncDefOp ComputeConstrainToProductPass::alignFuncs(
120 OpBuilder funcBuilder(compute);
123 FuncDefOp productFunc = funcBuilder.create<FuncDefOp>(
124 funcBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}),
FUNC_NAME_PRODUCT,
127 Block *entryBlock = productFunc.addEntryBlock();
128 funcBuilder.setInsertionPointToStart(entryBlock);
131 llvm::SmallVector<Value> args {productFunc.getArguments()};
134 CallOp computeCall = funcBuilder.
create<CallOp>(funcBuilder.getUnknownLoc(), compute, args);
135 args.insert(args.begin(), computeCall->getResult(0));
136 CallOp constrainCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), constrain, args);
137 funcBuilder.create<ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
140 InlinerInterface inliner(productFunc.getContext());
141 if (failed(inlineCall(inliner, computeCall, compute, &compute.
getBody(),
true))) {
145 if (failed(inlineCall(inliner, constrainCall, constrain, &constrain.
getBody(),
true))) {
149 computeCall->erase();
150 constrainCall->erase();
153 alignedStructs.push_back(root);
156 if (failed(alignCalls(productFunc, tables, equivalence))) {
162LogicalResult ComputeConstrainToProductPass::alignCalls(
163 FuncDefOp product, SymbolTableCollection &tables,
167 llvm::SetVector<CallOp> computeCalls, constrainCalls;
168 product.walk([&](CallOp callOp) {
170 computeCalls.insert(callOp);
172 constrainCalls.insert(callOp);
176 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
180 auto doCallsMatch = [&](CallOp compute, CallOp constrain) ->
bool {
182 llvm::outs() <<
"Asking for equivalence between calls\n"
183 << compute <<
"\nand\n"
184 << constrain <<
"\n\n";
185 llvm::outs() <<
"In block:\n\n" << *compute->getBlock() <<
"\n";
190 if (computeStruct != constrainStruct) {
193 for (
unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
202 for (
auto compute : computeCalls) {
204 auto matches = llvm::filter_to_vector(constrainCalls, [&](CallOp constrain) {
205 return doCallsMatch(compute, constrain);
208 if (matches.size() == 1) {
209 alignedCalls.insert({compute, matches[0]});
210 computeCalls.remove(compute);
211 constrainCalls.remove(matches[0]);
216 if (!computeCalls.empty() && constrainCalls.empty()) {
217 product->emitError() <<
"failed to align some @" <<
FUNC_NAME_COMPUTE <<
" and @"
222 for (
auto [compute, constrain] : alignedCalls) {
224 auto newRoot = compute.
getCalleeTarget(tables)->get()->getParentOfType<StructDefOp>();
226 FuncDefOp newProduct = alignFuncs(
227 newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp(), tables, equivalence
234 OpBuilder callBuilder(compute);
235 CallOp newCall = callBuilder.create<CallOp>(
236 callBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), newProduct,
237 compute.getOperands()
239 compute->replaceAllUsesWith(newCall.getResults());
248 return make_unique<ComputeConstrainToProductPass>();
void runOnOperation() override
bool areSignalsEquivalent(mlir::Value v1, mlir::Value v2)
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::FunctionType getFunctionType()
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
::mlir::Pass::Option< std::string > rootStruct
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
constexpr char FUNC_NAME_CONSTRAIN[]
std::unique_ptr< mlir::Pass > createComputeConstrainToProductPass()
constexpr char FUNC_NAME_PRODUCT[]
bool isValidRoot(StructDefOp root)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)