LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKComputeConstrainToProductPass.cpp
Go to the documentation of this file.
1//===-- LLZKComputeConstrainToProductPass.cpp -------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
19#include "llzk/Util/Constants.h"
22#include <mlir/IR/Builders.h>
23#include <mlir/Transforms/InliningUtils.h>
24
25#include <llvm/Support/Debug.h>
26
27#include <memory>
28#include <ranges>
29
30namespace llzk {
31#define GEN_PASS_DECL_COMPUTECONSTRAINTOPRODUCTPASS
32#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
34} // namespace llzk
36#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
38using namespace llzk::component;
39using namespace llzk::function;
40using namespace mlir;
41
42using std::make_unique;
43
44namespace llzk {
47 FuncDefOp computeFunc = root.getComputeFuncOp();
48 FuncDefOp constrainFunc = root.getConstrainFuncOp();
49
50 if (!computeFunc || !constrainFunc) {
51 root->emitError() << "no " << FUNC_NAME_COMPUTE << "/" << FUNC_NAME_CONSTRAIN << " to align";
52 return false;
53 }
54
56 /// root to start aligning from (issue #241)
57
58 return true;
59}
60
62 : public llzk::impl::ComputeConstrainToProductPassBase<ComputeConstrainToProductPass> {
63
64 std::vector<StructDefOp> alignedStructs;
65
66 // Given a @product function body, try to match up calls to @A::@compute and @A::@constrain for
67 // every sub-struct @A and replace them with a call to @A::@product
68 LogicalResult alignCalls(
69 FuncDefOp product, SymbolTableCollection &tables,
71 );
72
73 // Given a StructDefOp @root, replace the @root::@compute and @root::@constrain functions with a
74 // @root::@product
75 FuncDefOp alignFuncs(
76 StructDefOp root, FuncDefOp compute, FuncDefOp constrain, SymbolTableCollection &tables,
78 );
79
80public:
81 void runOnOperation() override {
82 ModuleOp mod = getOperation();
83 StructDefOp root;
84
85 SymbolTableCollection tables;
87 getAnalysis<LightweightSignalEquivalenceAnalysis>()
88 };
89
90 // Find the indicated root struct and make sure its a valid place to start aligning
91 mod.walk([&root, this](StructDefOp structDef) {
92 if (structDef.getSymName() == rootStruct) {
93 root = structDef;
94 }
95 });
96 if (!isValidRoot(root)) {
97 signalPassFailure();
98 return;
99 }
100
101 // Try aligning the root functions
102 if (!alignFuncs(
103 root, root.getComputeFuncOp(), root.getConstrainFuncOp(), tables, equivalence
104 )) {
105 signalPassFailure();
106 return;
107 }
108
109 for (auto s : alignedStructs) {
110 s.getComputeFuncOp()->erase();
111 s.getConstrainFuncOp()->erase();
112 }
113 }
114};
115
116FuncDefOp ComputeConstrainToProductPass::alignFuncs(
117 StructDefOp root, FuncDefOp compute, FuncDefOp constrain, SymbolTableCollection &tables,
119) {
120 OpBuilder funcBuilder(compute);
121
122 // Create an empty @product func...
123 FuncDefOp productFunc = funcBuilder.create<FuncDefOp>(
124 funcBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), FUNC_NAME_PRODUCT,
125 compute.getFunctionType()
126 );
127 Block *entryBlock = productFunc.addEntryBlock();
128 funcBuilder.setInsertionPointToStart(entryBlock);
129
130 // ...with the right arguments
131 llvm::SmallVector<Value> args {productFunc.getArguments()};
132
133 // Add calls to @compute and @constrain...
134 CallOp computeCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), compute, args);
135 args.insert(args.begin(), computeCall->getResult(0));
136 CallOp constrainCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), constrain, args);
137 funcBuilder.create<ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
138
139 // ..and inline them
140 InlinerInterface inliner(productFunc.getContext());
141 if (failed(inlineCall(inliner, computeCall, compute, &compute.getBody(), true))) {
142 root->emitError() << "failed to inline " << FUNC_NAME_COMPUTE;
143 return nullptr;
144 }
145 if (failed(inlineCall(inliner, constrainCall, constrain, &constrain.getBody(), true))) {
146 root->emitError() << "failed to inline " << FUNC_NAME_CONSTRAIN;
147 return nullptr;
148 }
149 computeCall->erase();
150 constrainCall->erase();
151
152 // Mark the compute/constrain for deletion
153 alignedStructs.push_back(root);
154
155 // Make sure we can align sub-calls to @compute and @constrain
156 if (failed(alignCalls(productFunc, tables, equivalence))) {
157 return nullptr;
158 }
159 return productFunc;
160}
161
162LogicalResult ComputeConstrainToProductPass::alignCalls(
163 FuncDefOp product, SymbolTableCollection &tables,
165) {
166 // Gather up all the remaining calls to @compute and @constrain
167 llvm::SetVector<CallOp> computeCalls, constrainCalls;
168 product.walk([&](CallOp callOp) {
169 if (callOp.calleeIsStructCompute()) {
170 computeCalls.insert(callOp);
171 } else if (callOp.calleeIsStructConstrain()) {
172 constrainCalls.insert(callOp);
173 }
174 });
175
176 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
177
178 // A @compute matches a @constrain if they belong to the same struct and all their input signals
179 // are pairwise equivalent
180 auto doCallsMatch = [&](CallOp compute, CallOp constrain) -> bool {
181 LLVM_DEBUG({
182 llvm::outs() << "Asking for equivalence between calls\n"
183 << compute << "\nand\n"
184 << constrain << "\n\n";
185 llvm::outs() << "In block:\n\n" << *compute->getBlock() << "\n";
186 });
187
188 auto computeStruct = getPrefixAsSymbolRefAttr(compute.getCallee());
189 auto constrainStruct = getPrefixAsSymbolRefAttr(constrain.getCallee());
190 if (computeStruct != constrainStruct) {
191 return false;
192 }
193 for (unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
194 if (!equivalence.areSignalsEquivalent(compute->getOperand(i), constrain->getOperand(i + 1))) {
195 return false;
196 }
197 }
198
199 return true;
200 };
201
202 for (auto compute : computeCalls) {
203 // If there is exactly one @compute that matches a given @constrain, we can align them
204 auto matches = llvm::filter_to_vector(constrainCalls, [&](CallOp constrain) {
205 return doCallsMatch(compute, constrain);
206 });
207
208 if (matches.size() == 1) {
209 alignedCalls.insert({compute, matches[0]});
210 computeCalls.remove(compute);
211 constrainCalls.remove(matches[0]);
212 }
213 }
214
215 // TODO: If unaligned calls remain, fully inline their structs and continue instead of failing
216 if (!computeCalls.empty() && constrainCalls.empty()) {
217 product->emitError() << "failed to align some @" << FUNC_NAME_COMPUTE << " and @"
219 return failure();
220 }
221
222 for (auto [compute, constrain] : alignedCalls) {
223 // If @A::@compute matches @A::@constrain, recursively align the functions in @A...
224 auto newRoot = compute.getCalleeTarget(tables)->get()->getParentOfType<StructDefOp>();
225 assert(newRoot);
226 FuncDefOp newProduct = alignFuncs(
227 newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp(), tables, equivalence
228 );
229 if (!newProduct) {
230 return failure();
231 }
232
233 // ...and replace the two calls with a single call to @A::@product
234 OpBuilder callBuilder(compute);
235 CallOp newCall = callBuilder.create<CallOp>(
236 callBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), newProduct,
237 compute.getOperands()
238 );
239 compute->replaceAllUsesWith(newCall.getResults());
240 compute->erase();
241 constrain->erase();
242 }
243
244 return success();
245}
246
247std::unique_ptr<mlir::Pass> createComputeConstrainToProductPass() {
248 return make_unique<ComputeConstrainToProductPass>();
249}
250
251} // namespace llzk
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1590
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:433
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:429
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:761
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:755
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:777
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
Definition Ops.h.inc:607
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
std::unique_ptr< mlir::Pass > createComputeConstrainToProductPass()
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:29
bool isValidRoot(StructDefOp root)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)