LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKPolyLoweringPass.cpp
Go to the documentation of this file.
1//===-- LLZKPolyLoweringPass.cpp --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
20
21#include <mlir/IR/BuiltinOps.h>
22
23#include <llvm/ADT/DenseMap.h>
24#include <llvm/ADT/DenseMapInfo.h>
25#include <llvm/ADT/SmallVector.h>
26#include <llvm/Support/Debug.h>
27
28#include <deque>
29#include <memory>
31// Include the generated base pass class definitions.
32namespace llzk {
33#define GEN_PASS_DECL_POLYLOWERINGPASS
34#define GEN_PASS_DEF_POLYLOWERINGPASS
36} // namespace llzk
38using namespace mlir;
39using namespace llzk;
40using namespace llzk::felt;
41using namespace llzk::function;
42using namespace llzk::component;
43using namespace llzk::constrain;
44
45#define DEBUG_TYPE "llzk-poly-lowering-pass"
46#define AUXILIARY_FIELD_PREFIX "__llzk_poly_lowering_pass_aux_field_"
47
48namespace {
49
50struct AuxAssignment {
51 std::string auxFieldName;
52 Value computedValue;
53};
54
55class PolyLoweringPass : public llzk::impl::PolyLoweringPassBase<PolyLoweringPass> {
56public:
57 void setMaxDegree(unsigned degree) { this->maxDegree = degree; }
59private:
60 unsigned auxCounter = 0;
61
62 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
63 modOp.walk([&structDefs](StructDefOp structDef) {
64 structDefs.push_back(structDef);
65 return WalkResult::skip();
66 });
67 }
68
69 void addAuxField(StructDefOp structDef, StringRef name) {
70 OpBuilder builder(structDef);
71 builder.setInsertionPointToEnd(&structDef.getBody().front());
72 builder.create<FieldDefOp>(
73 structDef.getLoc(), builder.getStringAttr(name), builder.getType<FeltType>()
74 );
75 }
76
77 // Recursively compute degree of FeltOps SSA values
78 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
79 if (memo.count(val)) {
80 return memo[val];
81 }
82 // Handle function parameters (BlockArguments)
83 if (val.isa<BlockArgument>()) {
84 memo[val] = 1;
85 return 1;
86 }
87 if (val.getDefiningOp<FeltConstantOp>()) {
88 return memo[val] = 0;
89 }
90 if (val.getDefiningOp<FeltNonDetOp>()) {
91 return memo[val] = 1;
92 }
93 if (val.getDefiningOp<FieldReadOp>()) {
94 return memo[val] = 1;
95 }
96 if (auto addOp = val.getDefiningOp<AddFeltOp>()) {
97 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
98 }
99 if (auto subOp = val.getDefiningOp<SubFeltOp>()) {
100 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
101 }
102 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
103 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
104 }
105 if (auto divOp = val.getDefiningOp<DivFeltOp>()) {
106 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
107 }
108 if (auto negOp = val.getDefiningOp<NegFeltOp>()) {
109 return memo[val] = getDegree(negOp.getOperand(), memo);
110 }
111
112 llvm_unreachable("Unhandled Felt SSA value in degree computation");
113 }
114
129 void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) {
130 assert(afterOp && "afterOp must be a valid Operation*");
131
132 for (auto &use : llvm::make_early_inc_range(oldVal.getUses())) {
133 Operation *user = use.getOwner();
134
135 // Skip uses that are:
136 // - Before afterOp in the same block.
137 // - Inside afterOp itself.
138 if ((user->getBlock() == afterOp->getBlock()) &&
139 (user->isBeforeInBlock(afterOp) || user == afterOp)) {
140 continue;
141 }
142
143 // Replace this use of oldVal with newVal.
144 use.set(newVal);
145 }
146 }
147
148 Value lowerExpression(
149 Value val, unsigned maxDegree, StructDefOp structDef, FuncDefOp constrainFunc,
150 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
151 SmallVector<AuxAssignment> &auxAssignments
152 ) {
153 if (rewrites.count(val)) {
154 return rewrites[val];
155 }
156
157 unsigned degree = getDegree(val, degreeMemo);
158 if (degree <= maxDegree) {
159 rewrites[val] = val;
160 return val;
161 }
162
163 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
164 // Recursively lower operands first
165 Value lhs = lowerExpression(
166 mulOp.getLhs(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
167 );
168 Value rhs = lowerExpression(
169 mulOp.getRhs(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
170 );
171
172 unsigned lhsDeg = getDegree(lhs, degreeMemo);
173 unsigned rhsDeg = getDegree(rhs, degreeMemo);
174
175 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
176 Value selfVal = constrainFunc.getArgument(0); // %self argument
177 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
178 // Optimization: If lhs == rhs, factor it only once
179 if (lhs == rhs && eraseMul) {
180 std::string auxName = AUXILIARY_FIELD_PREFIX + std::to_string(this->auxCounter++);
181 addAuxField(structDef, auxName);
182
183 auto auxVal = builder.create<FieldReadOp>(
184 lhs.getLoc(), lhs.getType(), selfVal, builder.getStringAttr(auxName)
185 );
186 auxAssignments.push_back({auxName, lhs});
187 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
188 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, lhs);
189
190 // Memoize auxVal as degree 1
191 degreeMemo[auxVal] = 1;
192 rewrites[lhs] = auxVal;
193 rewrites[rhs] = auxVal;
194 // Now selectively replace subsequent uses of lhs with auxVal
195 replaceSubsequentUsesWith(lhs, auxVal, eqOp);
196
197 // Update lhs and rhs to use auxVal
198 lhs = auxVal;
199 rhs = auxVal;
200
201 lhsDeg = rhsDeg = 1;
202 }
203 // While their product exceeds maxDegree, factor out one side
204 while (lhsDeg + rhsDeg > maxDegree) {
205 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
206
207 // Create auxiliary field for toFactor
208 std::string auxName = AUXILIARY_FIELD_PREFIX + std::to_string(this->auxCounter++);
209 addAuxField(structDef, auxName);
210
211 // Read back as FieldReadOp (new SSA value)
212 auto auxVal = builder.create<FieldReadOp>(
213 toFactor.getLoc(), toFactor.getType(), selfVal, builder.getStringAttr(auxName)
214 );
215
216 // Emit constraint: auxVal == toFactor
217 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
218 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, toFactor);
219 auxAssignments.push_back({auxName, toFactor});
220 // Update memoization
221 rewrites[toFactor] = auxVal;
222 degreeMemo[auxVal] = 1; // stays same
223 // replace the term with auxVal.
224 replaceSubsequentUsesWith(toFactor, auxVal, eqOp);
225
226 // Remap toFactor to auxVal for next iterations
227 toFactor = auxVal;
228
229 // Recompute degrees
230 lhsDeg = getDegree(lhs, degreeMemo);
231 rhsDeg = getDegree(rhs, degreeMemo);
232 }
233
234 // Now lhs * rhs fits within degree bound
235 auto mulVal = builder.create<MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
236 if (eraseMul) {
237 mulOp->replaceAllUsesWith(mulVal);
238 mulOp->erase();
239 }
240
241 // Result of this multiply has degree lhsDeg + rhsDeg
242 degreeMemo[mulVal] = lhsDeg + rhsDeg;
243 rewrites[val] = mulVal;
244
245 return mulVal;
246 }
247
248 // For non-mul ops, leave untouched (they're degree-1 safe)
249 rewrites[val] = val;
250 return val;
251 }
252
253 Value getSelfValueFromCompute(FuncDefOp computeFunc) {
254 // Get the single block of the function body
255 Region &body = computeFunc.getBody();
256 assert(!body.empty() && "compute() function body is empty");
257
258 Block &block = body.front();
259
260 // The terminator should be the return op
261 Operation *terminator = block.getTerminator();
262 assert(terminator && "compute() function has no terminator");
263
264 // The return op should be of type ReturnOp
265 auto retOp = dyn_cast<ReturnOp>(terminator);
266 if (!retOp) {
267 llvm::errs() << "Expected ReturnOp as terminator in compute() but found: "
268 << terminator->getName() << "\n";
269 llvm_unreachable("compute() function terminator is not a ReturnOp");
270 }
271
272 // Return its operands as SmallVector<Value>
273 return retOp.getOperands().front();
274 }
275
276 Value rebuildExprInCompute(
277 Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap<Value, Value> &rebuildMemo
278 ) {
279 // Memoized already?
280 if (auto it = rebuildMemo.find(val); it != rebuildMemo.end()) {
281 return it->second;
282 }
283
284 // Case 1: BlockArgument from constrain() -> map to compute()
285 if (auto barg = val.dyn_cast<BlockArgument>()) {
286 unsigned index = barg.getArgNumber(); // Argument index in constrain()
287 Value computeArg = computeFunc.getArgument(index - 1); // Corresponding compute() arg
288 rebuildMemo[val] = computeArg;
289 return computeArg;
290 }
291
292 // Case 2: FieldReadOp in constrain() -> replicate FieldReadOp in compute()
293 if (auto readOp = val.getDefiningOp<FieldReadOp>()) {
294 Value selfVal = getSelfValueFromCompute(computeFunc); // %self is always the return value
295 auto rebuiltRead = builder.create<FieldReadOp>(
296 readOp.getLoc(), readOp.getType(), selfVal, readOp.getFieldNameAttr().getAttr()
297 );
298 rebuildMemo[val] = rebuiltRead.getResult();
299 return rebuiltRead.getResult();
300 }
301
302 // Case 3: AddFeltOp
303 if (auto addOp = val.getDefiningOp<AddFeltOp>()) {
304 Value lhs = rebuildExprInCompute(addOp.getLhs(), computeFunc, builder, rebuildMemo);
305 Value rhs = rebuildExprInCompute(addOp.getRhs(), computeFunc, builder, rebuildMemo);
306 auto rebuiltAdd = builder.create<AddFeltOp>(addOp.getLoc(), addOp.getType(), lhs, rhs);
307 rebuildMemo[val] = rebuiltAdd.getResult();
308 return rebuiltAdd.getResult();
309 }
310
311 // Case 4: SubFeltOp
312 if (auto subOp = val.getDefiningOp<SubFeltOp>()) {
313 Value lhs = rebuildExprInCompute(subOp.getLhs(), computeFunc, builder, rebuildMemo);
314 Value rhs = rebuildExprInCompute(subOp.getRhs(), computeFunc, builder, rebuildMemo);
315 auto rebuiltSub = builder.create<SubFeltOp>(subOp.getLoc(), subOp.getType(), lhs, rhs);
316 rebuildMemo[val] = rebuiltSub.getResult();
317 return rebuiltSub.getResult();
318 }
319
320 // Case 5: MulFeltOp
321 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
322 Value lhs = rebuildExprInCompute(mulOp.getLhs(), computeFunc, builder, rebuildMemo);
323 Value rhs = rebuildExprInCompute(mulOp.getRhs(), computeFunc, builder, rebuildMemo);
324 auto rebuiltMul = builder.create<MulFeltOp>(mulOp.getLoc(), mulOp.getType(), lhs, rhs);
325 rebuildMemo[val] = rebuiltMul.getResult();
326 return rebuiltMul.getResult();
327 }
328
329 // Case 6: NegFeltOp
330 if (auto negOp = val.getDefiningOp<NegFeltOp>()) {
331 Value inner = rebuildExprInCompute(negOp.getOperand(), computeFunc, builder, rebuildMemo);
332 auto rebuiltNeg = builder.create<NegFeltOp>(negOp.getLoc(), negOp.getType(), inner);
333 rebuildMemo[val] = rebuiltNeg.getResult();
334 return rebuiltNeg.getResult();
335 }
336
337 // Case 7: DivFeltOp
338 if (auto divOp = val.getDefiningOp<DivFeltOp>()) {
339 Value lhs = rebuildExprInCompute(divOp.getLhs(), computeFunc, builder, rebuildMemo);
340 Value rhs = rebuildExprInCompute(divOp.getRhs(), computeFunc, builder, rebuildMemo);
341 auto rebuiltDiv = builder.create<DivFeltOp>(divOp.getLoc(), divOp.getType(), lhs, rhs);
342 rebuildMemo[val] = rebuiltDiv.getResult();
343 return rebuiltDiv.getResult();
344 }
345
346 // Case 8: ConstFeltOp
347 if (auto constOp = val.getDefiningOp<FeltConstantOp>()) {
348 auto newConst = builder.create<FeltConstantOp>(constOp.getLoc(), constOp.getValue());
349 rebuildMemo[val] = newConst.getResult();
350 return newConst.getResult();
351 }
352
353 llvm::errs() << "Unhandled expression kind in rebuildExprInCompute: " << val << "\n";
354 llvm_unreachable("Unsupported op in rebuildExprInCompute");
355 }
356
357 // Throw an error if the struct has a field that matches the prefix of the auxiliary fields
358 // we use in the pass. There **shouldn't** be a conflict but just in case let's throw the check.
359 void checkForAuxFieldConflicts(StructDefOp structDef) {
360 structDef.walk([&](FieldDefOp fieldDefOp) {
361 if (fieldDefOp.getName().starts_with(AUXILIARY_FIELD_PREFIX)) {
362 fieldDefOp.emitError() << "Field name: \"" << fieldDefOp.getName()
363 << "\" starts with prefix: \"" << AUXILIARY_FIELD_PREFIX
364 << "\" which is reserved for lowering pass";
365 signalPassFailure();
366 return;
367 }
368 });
369 }
370
371 void runOnOperation() override {
372 ModuleOp moduleOp = getOperation();
373
374 // Validate degree parameter
375 if (maxDegree < 2) {
376 moduleOp.emitError() << "Invalid max degree: " << maxDegree.getValue() << ". Must be >= 2.";
377 signalPassFailure();
378 return;
379 }
380
381 moduleOp.walk([&](StructDefOp structDef) {
382 FuncDefOp constrainFunc = structDef.getConstrainFuncOp();
383 FuncDefOp computeFunc = structDef.getComputeFuncOp();
384 if (!constrainFunc) {
385 structDef.emitOpError() << "\"" << structDef.getName() << "\" doesn't have a '"
386 << FUNC_NAME_CONSTRAIN << "' function";
387 signalPassFailure();
388 return;
389 }
390
391 if (!computeFunc) {
392 structDef.emitOpError() << "\"" << structDef.getName() << "\" doesn't have a '"
393 << FUNC_NAME_COMPUTE << "' function";
394 signalPassFailure();
395 return;
396 }
397
398 checkForAuxFieldConflicts(structDef);
399
400 DenseMap<Value, unsigned> degreeMemo;
401 DenseMap<Value, Value> rewrites;
402 SmallVector<AuxAssignment> auxAssignments;
403
404 // Lower equality constraints
405 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
406 auto &lhsOperand = constraintOp.getLhsMutable();
407 auto &rhsOperand = constraintOp.getRhsMutable();
408 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
409 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
410
411 if (degreeLhs > maxDegree) {
412 Value loweredExpr = lowerExpression(
413 lhsOperand.get(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites,
414 auxAssignments
415 );
416 lhsOperand.set(loweredExpr);
417 }
418 if (degreeRhs > maxDegree) {
419 Value loweredExpr = lowerExpression(
420 rhsOperand.get(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites,
421 auxAssignments
422 );
423 rhsOperand.set(loweredExpr);
424 }
425 });
426
427 // The pass doesn't currently support EmitContainmentOp as it depends on
428 // https://veridise.atlassian.net/browse/LLZK-245 being fixed Once this is fixed, the op
429 // should lower all the elements in the row being looked up
430 constrainFunc.walk([&](EmitContainmentOp containOp) {
431 moduleOp.emitError() << "EmitContainmentOp is unsupported for now in the lowering pass";
432 signalPassFailure();
433 return;
434 });
435
436 // Lower function call arguments
437 constrainFunc.walk([&](CallOp callOp) {
438 if (callOp.calleeIsStructConstrain()) {
439 SmallVector<Value> newOperands = llvm::to_vector(callOp.getArgOperands());
440 bool modified = false;
441
442 for (Value &arg : newOperands) {
443 unsigned deg = getDegree(arg, degreeMemo);
444
445 if (deg > 1) {
446 Value loweredArg = lowerExpression(
447 arg, maxDegree, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
448 );
449 arg = loweredArg;
450 modified = true;
451 }
452 }
453
454 if (modified) {
455 SmallVector<ValueRange> mapOperands;
456 OpBuilder builder(callOp);
457 for (auto group : callOp.getMapOperands()) {
458 mapOperands.push_back(group);
459 }
460
461 builder.create<CallOp>(
462 callOp.getLoc(), callOp.getResultTypes(), callOp.getCallee(), mapOperands,
463 callOp.getNumDimsPerMap(), newOperands
464 );
465 callOp->erase();
466 }
467 }
468 });
469
470 DenseMap<Value, Value> rebuildMemo;
471 Block &computeBlock = computeFunc.getBody().front();
472 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
473 Value selfVal = getSelfValueFromCompute(computeFunc);
474
475 for (const auto &assign : auxAssignments) {
476 Value rebuiltExpr =
477 rebuildExprInCompute(assign.computedValue, computeFunc, builder, rebuildMemo);
478 builder.create<FieldWriteOp>(
479 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName),
480 rebuiltExpr
481 );
482 }
483 });
484 }
485};
486} // namespace
487
488std::unique_ptr<mlir::Pass> llzk::createPolyLoweringPass() {
489 return std::make_unique<PolyLoweringPass>();
490};
491
492std::unique_ptr<mlir::Pass> llzk::createPolyLoweringPass(unsigned maxDegree) {
493 auto pass = std::make_unique<PolyLoweringPass>();
494 static_cast<PolyLoweringPass *>(pass.get())->setMaxDegree(maxDegree);
495 return pass;
496}
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
#define AUXILIARY_FIELD_PREFIX
::mlir::Region & getBody()
Definition Ops.cpp.inc:1810
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
Definition Ops.cpp:142
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
Definition Ops.cpp:353
::mlir::OpOperand & getRhsMutable()
Definition Ops.cpp.inc:293
::mlir::OpOperand & getLhsMutable()
Definition Ops.cpp.inc:288
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:694
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:535
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:27
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
std::unique_ptr< mlir::Pass > createPolyLoweringPass()