LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKPolyLoweringPass.cpp
Go to the documentation of this file.
1//===-- LLZKPolyLoweringPass.cpp --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
21
22#include <mlir/IR/BuiltinOps.h>
23
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
28
29#include <deque>
30#include <memory>
32// Include the generated base pass class definitions.
33namespace llzk {
34#define GEN_PASS_DECL_POLYLOWERINGPASS
35#define GEN_PASS_DEF_POLYLOWERINGPASS
37} // namespace llzk
39using namespace mlir;
40using namespace llzk;
41using namespace llzk::felt;
42using namespace llzk::function;
43using namespace llzk::component;
44using namespace llzk::constrain;
45
46#define DEBUG_TYPE "llzk-poly-lowering-pass"
47#define AUXILIARY_FIELD_PREFIX "__llzk_poly_lowering_pass_aux_field_"
48
49namespace {
50
51struct AuxAssignment {
52 std::string auxFieldName;
53 Value computedValue;
54};
55
56class PolyLoweringPass : public llzk::impl::PolyLoweringPassBase<PolyLoweringPass> {
57public:
58 void setMaxDegree(unsigned degree) { this->maxDegree = degree; }
60private:
61 unsigned auxCounter = 0;
62
63 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
64 modOp.walk([&structDefs](StructDefOp structDef) {
65 structDefs.push_back(structDef);
66 return WalkResult::skip();
67 });
68 }
69
70 // Recursively compute degree of FeltOps SSA values
71 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
72 if (auto it = memo.find(val); it != memo.end()) {
73 return it->second;
74 }
75 // Handle function parameters (BlockArguments)
76 if (val.isa<BlockArgument>()) {
77 memo[val] = 1;
78 return 1;
79 }
80 if (val.getDefiningOp<FeltConstantOp>()) {
81 return memo[val] = 0;
82 }
83 if (val.getDefiningOp<FeltNonDetOp>()) {
84 return memo[val] = 1;
85 }
86 if (val.getDefiningOp<FieldReadOp>()) {
87 return memo[val] = 1;
88 }
89 if (auto addOp = val.getDefiningOp<AddFeltOp>()) {
90 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
91 }
92 if (auto subOp = val.getDefiningOp<SubFeltOp>()) {
93 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
94 }
95 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
96 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
97 }
98 if (auto divOp = val.getDefiningOp<DivFeltOp>()) {
99 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
100 }
101 if (auto negOp = val.getDefiningOp<NegFeltOp>()) {
102 return memo[val] = getDegree(negOp.getOperand(), memo);
103 }
104
105 llvm_unreachable("Unhandled Felt SSA value in degree computation");
106 }
107
108 Value lowerExpression(
109 Value val, StructDefOp structDef, FuncDefOp constrainFunc,
110 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
111 SmallVector<AuxAssignment> &auxAssignments
112 ) {
113 if (rewrites.count(val)) {
114 return rewrites[val];
115 }
116
117 unsigned degree = getDegree(val, degreeMemo);
118 if (degree <= maxDegree) {
119 rewrites[val] = val;
120 return val;
121 }
122
123 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
124 // Recursively lower operands first
125 Value lhs = lowerExpression(
126 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
127 );
128 Value rhs = lowerExpression(
129 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
130 );
131
132 unsigned lhsDeg = getDegree(lhs, degreeMemo);
133 unsigned rhsDeg = getDegree(rhs, degreeMemo);
134
135 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
136 Value selfVal = constrainFunc.getArgument(0); // %self argument
137 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
138 // Optimization: If lhs == rhs, factor it only once
139 if (lhs == rhs && eraseMul) {
140 std::string auxName = AUXILIARY_FIELD_PREFIX + std::to_string(this->auxCounter++);
141 FieldDefOp auxField = addAuxField(structDef, auxName);
142
143 auto auxVal = builder.create<FieldReadOp>(
144 lhs.getLoc(), lhs.getType(), selfVal, auxField.getNameAttr()
145 );
146 auxAssignments.push_back({auxName, lhs});
147 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
148 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, lhs);
149
150 // Memoize auxVal as degree 1
151 degreeMemo[auxVal] = 1;
152 rewrites[lhs] = auxVal;
153 rewrites[rhs] = auxVal;
154 // Now selectively replace subsequent uses of lhs with auxVal
155 replaceSubsequentUsesWith(lhs, auxVal, eqOp);
156
157 // Update lhs and rhs to use auxVal
158 lhs = auxVal;
159 rhs = auxVal;
160
161 lhsDeg = rhsDeg = 1;
162 }
163 // While their product exceeds maxDegree, factor out one side
164 while (lhsDeg + rhsDeg > maxDegree) {
165 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
166
167 // Create auxiliary field for toFactor
168 std::string auxName = AUXILIARY_FIELD_PREFIX + std::to_string(this->auxCounter++);
169 FieldDefOp auxField = addAuxField(structDef, auxName);
170
171 // Read back as FieldReadOp (new SSA value)
172 auto auxVal = builder.create<FieldReadOp>(
173 toFactor.getLoc(), toFactor.getType(), selfVal, auxField.getNameAttr()
174 );
175
176 // Emit constraint: auxVal == toFactor
177 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
178 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, toFactor);
179 auxAssignments.push_back({auxName, toFactor});
180 // Update memoization
181 rewrites[toFactor] = auxVal;
182 degreeMemo[auxVal] = 1; // stays same
183 // replace the term with auxVal.
184 replaceSubsequentUsesWith(toFactor, auxVal, eqOp);
185
186 // Remap toFactor to auxVal for next iterations
187 toFactor = auxVal;
188
189 // Recompute degrees
190 lhsDeg = getDegree(lhs, degreeMemo);
191 rhsDeg = getDegree(rhs, degreeMemo);
192 }
193
194 // Now lhs * rhs fits within degree bound
195 auto mulVal = builder.create<MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
196 if (eraseMul) {
197 mulOp->replaceAllUsesWith(mulVal);
198 mulOp->erase();
199 }
200
201 // Result of this multiply has degree lhsDeg + rhsDeg
202 degreeMemo[mulVal] = lhsDeg + rhsDeg;
203 rewrites[val] = mulVal;
204
205 return mulVal;
206 }
207
208 // For non-mul ops, leave untouched (they're degree-1 safe)
209 rewrites[val] = val;
210 return val;
211 }
212
213 void runOnOperation() override {
214 ModuleOp moduleOp = getOperation();
215
216 // Validate degree parameter
217 if (maxDegree < 2) {
218 auto diag = moduleOp.emitError();
219 diag << "Invalid max degree: " << maxDegree.getValue() << ". Must be >= 2.";
220 diag.report();
221 signalPassFailure();
222 return;
223 }
224
225 moduleOp.walk([this, &moduleOp](StructDefOp structDef) {
226 FuncDefOp constrainFunc = structDef.getConstrainFuncOp();
227 FuncDefOp computeFunc = structDef.getComputeFuncOp();
228 if (!constrainFunc) {
229 auto diag = structDef.emitOpError();
230 diag << '"' << structDef.getName() << "\" doesn't have a \"@" << FUNC_NAME_CONSTRAIN
231 << "\" function";
232 diag.report();
233 signalPassFailure();
234 return;
235 }
236
237 if (!computeFunc) {
238 auto diag = structDef.emitOpError();
239 diag << '"' << structDef.getName() << "\" doesn't have a \"@" << FUNC_NAME_COMPUTE
240 << "\" function";
241 diag.report();
242 signalPassFailure();
243 return;
244 }
245
246 if (failed(checkForAuxFieldConflicts(structDef, AUXILIARY_FIELD_PREFIX))) {
247 signalPassFailure();
248 return;
249 }
250
251 DenseMap<Value, unsigned> degreeMemo;
252 DenseMap<Value, Value> rewrites;
253 SmallVector<AuxAssignment> auxAssignments;
254
255 // Lower equality constraints
256 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
257 auto &lhsOperand = constraintOp.getLhsMutable();
258 auto &rhsOperand = constraintOp.getRhsMutable();
259 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
260 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
261
262 if (degreeLhs > maxDegree) {
263 Value loweredExpr = lowerExpression(
264 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
265 );
266 lhsOperand.set(loweredExpr);
267 }
268 if (degreeRhs > maxDegree) {
269 Value loweredExpr = lowerExpression(
270 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
271 );
272 rhsOperand.set(loweredExpr);
273 }
274 });
275
276 // The pass doesn't currently support EmitContainmentOp as it depends on
277 // https://veridise.atlassian.net/browse/LLZK-245 being fixed Once this is fixed, the op
278 // should lower all the elements in the row being looked up
279 constrainFunc.walk([this, &moduleOp](EmitContainmentOp containOp) {
280 auto diag = moduleOp.emitError();
281 diag << "EmitContainmentOp is unsupported for now in the lowering pass";
282 diag.report();
283 signalPassFailure();
284 return;
285 });
286
287 // Lower function call arguments
288 constrainFunc.walk([&](CallOp callOp) {
289 if (callOp.calleeIsStructConstrain()) {
290 SmallVector<Value> newOperands = llvm::to_vector(callOp.getArgOperands());
291 bool modified = false;
292
293 for (Value &arg : newOperands) {
294 unsigned deg = getDegree(arg, degreeMemo);
295
296 if (deg > 1) {
297 Value loweredArg = lowerExpression(
298 arg, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
299 );
300 arg = loweredArg;
301 modified = true;
302 }
303 }
304
305 if (modified) {
306 SmallVector<ValueRange> mapOperands;
307 OpBuilder builder(callOp);
308 for (auto group : callOp.getMapOperands()) {
309 mapOperands.push_back(group);
310 }
311
312 builder.create<CallOp>(
313 callOp.getLoc(), callOp.getResultTypes(), callOp.getCallee(), mapOperands,
314 callOp.getNumDimsPerMap(), newOperands
315 );
316 callOp->erase();
317 }
318 }
319 });
320
321 DenseMap<Value, Value> rebuildMemo;
322 Block &computeBlock = computeFunc.getBody().front();
323 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
324 Value selfVal = getSelfValueFromCompute(computeFunc);
325
326 for (const auto &assign : auxAssignments) {
327 Value rebuiltExpr =
328 rebuildExprInCompute(assign.computedValue, computeFunc, builder, rebuildMemo);
329 builder.create<FieldWriteOp>(
330 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName),
331 rebuiltExpr
332 );
333 }
334 });
335 }
336};
337} // namespace
338
339std::unique_ptr<mlir::Pass> llzk::createPolyLoweringPass() {
340 return std::make_unique<PolyLoweringPass>();
341};
342
343std::unique_ptr<mlir::Pass> llzk::createPolyLoweringPass(unsigned maxDegree) {
344 auto pass = std::make_unique<PolyLoweringPass>();
345 static_cast<PolyLoweringPass *>(pass.get())->setMaxDegree(maxDegree);
346 return pass;
347}
#define AUXILIARY_FIELD_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
Definition Ops.cpp:353
::mlir::OpOperand & getRhsMutable()
Definition Ops.cpp.inc:293
::mlir::OpOperand & getLhsMutable()
Definition Ops.cpp.inc:288
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:694
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:535
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
Value getSelfValueFromCompute(FuncDefOp computeFunc)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
std::unique_ptr< mlir::Pass > createPolyLoweringPass()