21#include <mlir/IR/BuiltinOps.h>
23#include <llvm/ADT/DenseMap.h>
24#include <llvm/ADT/DenseMapInfo.h>
25#include <llvm/ADT/SmallVector.h>
26#include <llvm/Support/Debug.h>
33#define GEN_PASS_DECL_POLYLOWERINGPASS
34#define GEN_PASS_DEF_POLYLOWERINGPASS
45#define DEBUG_TYPE "llzk-poly-lowering-pass"
46#define AUXILIARY_FIELD_PREFIX "__llzk_poly_lowering_pass_aux_field_"
51 std::string auxFieldName;
57 void setMaxDegree(
unsigned degree) { this->maxDegree = degree; }
60 unsigned auxCounter = 0;
62 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
64 structDefs.push_back(structDef);
65 return WalkResult::skip();
69 void addAuxField(
StructDefOp structDef, StringRef name) {
70 OpBuilder builder(structDef);
71 builder.setInsertionPointToEnd(&structDef.
getBody().front());
73 structDef.getLoc(), builder.getStringAttr(name), builder.
getType<
FeltType>()
78 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
79 if (memo.count(val)) {
83 if (val.isa<BlockArgument>()) {
96 if (
auto addOp = val.getDefiningOp<AddFeltOp>()) {
97 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
99 if (
auto subOp = val.getDefiningOp<SubFeltOp>()) {
100 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
102 if (
auto mulOp = val.getDefiningOp<MulFeltOp>()) {
103 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
105 if (
auto divOp = val.getDefiningOp<DivFeltOp>()) {
106 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
108 if (
auto negOp = val.getDefiningOp<NegFeltOp>()) {
109 return memo[val] = getDegree(negOp.getOperand(), memo);
112 llvm_unreachable(
"Unhandled Felt SSA value in degree computation");
129 void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) {
130 assert(afterOp &&
"afterOp must be a valid Operation*");
132 for (
auto &
use : llvm::make_early_inc_range(oldVal.getUses())) {
133 Operation *user =
use.getOwner();
138 if ((user->getBlock() == afterOp->getBlock()) &&
139 (user->isBeforeInBlock(afterOp) || user == afterOp)) {
148 Value lowerExpression(
149 Value val,
unsigned maxDegree, StructDefOp structDef, FuncDefOp constrainFunc,
150 DenseMap<Value, unsigned> °reeMemo, DenseMap<Value, Value> &rewrites,
151 SmallVector<AuxAssignment> &auxAssignments
153 if (rewrites.count(val)) {
154 return rewrites[val];
157 unsigned degree = getDegree(val, degreeMemo);
158 if (degree <= maxDegree) {
163 if (
auto mulOp = val.getDefiningOp<MulFeltOp>()) {
165 Value lhs = lowerExpression(
166 mulOp.getLhs(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
168 Value rhs = lowerExpression(
169 mulOp.getRhs(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
172 unsigned lhsDeg = getDegree(lhs, degreeMemo);
173 unsigned rhsDeg = getDegree(rhs, degreeMemo);
175 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
176 Value selfVal = constrainFunc.getArgument(0);
177 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
179 if (lhs == rhs && eraseMul) {
181 addAuxField(structDef, auxName);
183 auto auxVal = builder.create<FieldReadOp>(
184 lhs.getLoc(), lhs.getType(), selfVal, builder.getStringAttr(auxName)
186 auxAssignments.push_back({auxName, lhs});
187 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
188 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, lhs);
191 degreeMemo[auxVal] = 1;
192 rewrites[lhs] = auxVal;
193 rewrites[rhs] = auxVal;
195 replaceSubsequentUsesWith(lhs, auxVal, eqOp);
204 while (lhsDeg + rhsDeg > maxDegree) {
205 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
209 addAuxField(structDef, auxName);
212 auto auxVal = builder.create<FieldReadOp>(
213 toFactor.getLoc(), toFactor.getType(), selfVal, builder.getStringAttr(auxName)
217 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
218 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, toFactor);
219 auxAssignments.push_back({auxName, toFactor});
221 rewrites[toFactor] = auxVal;
222 degreeMemo[auxVal] = 1;
224 replaceSubsequentUsesWith(toFactor, auxVal, eqOp);
230 lhsDeg = getDegree(lhs, degreeMemo);
231 rhsDeg = getDegree(rhs, degreeMemo);
235 auto mulVal = builder.
create<MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
237 mulOp->replaceAllUsesWith(mulVal);
242 degreeMemo[mulVal] = lhsDeg + rhsDeg;
243 rewrites[val] = mulVal;
253 Value getSelfValueFromCompute(FuncDefOp computeFunc) {
255 Region &body = computeFunc.
getBody();
256 assert(!body.empty() &&
"compute() function body is empty");
258 Block &block = body.front();
261 Operation *terminator = block.getTerminator();
262 assert(terminator &&
"compute() function has no terminator");
265 auto retOp = dyn_cast<ReturnOp>(terminator);
267 llvm::errs() <<
"Expected ReturnOp as terminator in compute() but found: "
268 << terminator->getName() <<
"\n";
269 llvm_unreachable(
"compute() function terminator is not a ReturnOp");
273 return retOp.getOperands().front();
276 Value rebuildExprInCompute(
277 Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap<Value, Value> &rebuildMemo
280 if (
auto it = rebuildMemo.find(val); it != rebuildMemo.end()) {
285 if (
auto barg = val.dyn_cast<BlockArgument>()) {
286 unsigned index = barg.getArgNumber();
287 Value computeArg = computeFunc.getArgument(index - 1);
288 rebuildMemo[val] = computeArg;
293 if (
auto readOp = val.getDefiningOp<FieldReadOp>()) {
294 Value selfVal = getSelfValueFromCompute(computeFunc);
295 auto rebuiltRead = builder.create<FieldReadOp>(
296 readOp.getLoc(), readOp.getType(), selfVal, readOp.getFieldNameAttr().getAttr()
298 rebuildMemo[val] = rebuiltRead.getResult();
299 return rebuiltRead.getResult();
303 if (
auto addOp = val.getDefiningOp<AddFeltOp>()) {
304 Value lhs = rebuildExprInCompute(addOp.getLhs(), computeFunc, builder, rebuildMemo);
305 Value rhs = rebuildExprInCompute(addOp.getRhs(), computeFunc, builder, rebuildMemo);
306 auto rebuiltAdd = builder.create<AddFeltOp>(addOp.getLoc(), addOp.getType(), lhs, rhs);
307 rebuildMemo[val] = rebuiltAdd.getResult();
308 return rebuiltAdd.getResult();
312 if (
auto subOp = val.getDefiningOp<SubFeltOp>()) {
313 Value lhs = rebuildExprInCompute(subOp.getLhs(), computeFunc, builder, rebuildMemo);
314 Value rhs = rebuildExprInCompute(subOp.getRhs(), computeFunc, builder, rebuildMemo);
315 auto rebuiltSub = builder.create<SubFeltOp>(subOp.getLoc(), subOp.getType(), lhs, rhs);
316 rebuildMemo[val] = rebuiltSub.getResult();
317 return rebuiltSub.getResult();
321 if (
auto mulOp = val.getDefiningOp<MulFeltOp>()) {
322 Value lhs = rebuildExprInCompute(mulOp.getLhs(), computeFunc, builder, rebuildMemo);
323 Value rhs = rebuildExprInCompute(mulOp.getRhs(), computeFunc, builder, rebuildMemo);
324 auto rebuiltMul = builder.create<MulFeltOp>(mulOp.getLoc(), mulOp.getType(), lhs, rhs);
325 rebuildMemo[val] = rebuiltMul.getResult();
326 return rebuiltMul.getResult();
330 if (
auto negOp = val.getDefiningOp<NegFeltOp>()) {
331 Value inner = rebuildExprInCompute(negOp.getOperand(), computeFunc, builder, rebuildMemo);
332 auto rebuiltNeg = builder.create<NegFeltOp>(negOp.getLoc(), negOp.getType(), inner);
333 rebuildMemo[val] = rebuiltNeg.getResult();
334 return rebuiltNeg.getResult();
338 if (
auto divOp = val.getDefiningOp<DivFeltOp>()) {
339 Value lhs = rebuildExprInCompute(divOp.getLhs(), computeFunc, builder, rebuildMemo);
340 Value rhs = rebuildExprInCompute(divOp.getRhs(), computeFunc, builder, rebuildMemo);
341 auto rebuiltDiv = builder.create<DivFeltOp>(divOp.getLoc(), divOp.getType(), lhs, rhs);
342 rebuildMemo[val] = rebuiltDiv.getResult();
343 return rebuiltDiv.getResult();
347 if (
auto constOp = val.getDefiningOp<FeltConstantOp>()) {
348 auto newConst = builder.create<FeltConstantOp>(constOp.getLoc(), constOp.getValue());
349 rebuildMemo[val] = newConst.getResult();
350 return newConst.getResult();
353 llvm::errs() <<
"Unhandled expression kind in rebuildExprInCompute: " << val <<
"\n";
354 llvm_unreachable(
"Unsupported op in rebuildExprInCompute");
359 void checkForAuxFieldConflicts(StructDefOp structDef) {
360 structDef.walk([&](FieldDefOp fieldDefOp) {
362 fieldDefOp.emitError() <<
"Field name: \"" << fieldDefOp.getName()
364 <<
"\" which is reserved for lowering pass";
371 void runOnOperation()
override {
372 ModuleOp moduleOp = getOperation();
376 moduleOp.emitError() <<
"Invalid max degree: " << maxDegree.getValue() <<
". Must be >= 2.";
381 moduleOp.walk([&](StructDefOp structDef) {
384 if (!constrainFunc) {
385 structDef.emitOpError() <<
"\"" << structDef.getName() <<
"\" doesn't have a '"
392 structDef.emitOpError() <<
"\"" << structDef.getName() <<
"\" doesn't have a '"
398 checkForAuxFieldConflicts(structDef);
400 DenseMap<Value, unsigned> degreeMemo;
401 DenseMap<Value, Value> rewrites;
402 SmallVector<AuxAssignment> auxAssignments;
405 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
408 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
409 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
411 if (degreeLhs > maxDegree) {
412 Value loweredExpr = lowerExpression(
413 lhsOperand.get(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites,
416 lhsOperand.set(loweredExpr);
418 if (degreeRhs > maxDegree) {
419 Value loweredExpr = lowerExpression(
420 rhsOperand.get(), maxDegree, structDef, constrainFunc, degreeMemo, rewrites,
423 rhsOperand.set(loweredExpr);
430 constrainFunc.walk([&](EmitContainmentOp containOp) {
431 moduleOp.emitError() <<
"EmitContainmentOp is unsupported for now in the lowering pass";
437 constrainFunc.walk([&](CallOp callOp) {
439 SmallVector<Value> newOperands = llvm::to_vector(callOp.
getArgOperands());
440 bool modified =
false;
442 for (Value &arg : newOperands) {
443 unsigned deg = getDegree(arg, degreeMemo);
446 Value loweredArg = lowerExpression(
447 arg, maxDegree, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
455 SmallVector<ValueRange> mapOperands;
456 OpBuilder builder(callOp);
458 mapOperands.push_back(group);
461 builder.create<CallOp>(
462 callOp.getLoc(), callOp.getResultTypes(), callOp.
getCallee(), mapOperands,
470 DenseMap<Value, Value> rebuildMemo;
471 Block &computeBlock = computeFunc.
getBody().front();
472 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
473 Value selfVal = getSelfValueFromCompute(computeFunc);
475 for (
const auto &assign : auxAssignments) {
477 rebuildExprInCompute(assign.computedValue, computeFunc, builder, rebuildMemo);
478 builder.create<FieldWriteOp>(
479 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName),
489 return std::make_unique<PolyLoweringPass>();
493 auto pass = std::make_unique<PolyLoweringPass>();
494 static_cast<PolyLoweringPass *
>(pass.get())->setMaxDegree(maxDegree);
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
#define AUXILIARY_FIELD_PREFIX
::mlir::Region & getBody()
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
::mlir::OpOperand & getRhsMutable()
::mlir::OpOperand & getLhsMutable()
::mlir::OperandRangeRange getMapOperands()
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::mlir::Operation::operand_range getArgOperands()
::mlir::Region & getBody()
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
constexpr char FUNC_NAME_CONSTRAIN[]
std::unique_ptr< mlir::Pass > createPolyLoweringPass()