22#include <mlir/IR/BuiltinOps.h>
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
34#define GEN_PASS_DECL_POLYLOWERINGPASS
35#define GEN_PASS_DEF_POLYLOWERINGPASS
46#define DEBUG_TYPE "llzk-poly-lowering-pass"
47#define AUXILIARY_FIELD_PREFIX "__llzk_poly_lowering_pass_aux_field_"
52 std::string auxFieldName;
// Sets `maxDegree`, the upper bound that lowerExpression/runOnOperation compare
// expression degrees against.
58 void setMaxDegree(
unsigned degree) { this->maxDegree = degree; }
61 unsigned auxCounter = 0;
// Appends each StructDefOp visited in `modOp` to `structDefs`.
// The walk callback returns WalkResult::skip(), so the walk does not descend
// into the body of a struct once that struct has been collected.
63 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
65 structDefs.push_back(structDef);
66 return WalkResult::skip();
// Returns the polynomial degree of the felt SSA value `val`, memoizing every
// computed result in `memo`. Degree rules for the visible cases:
//   - AddFeltOp / SubFeltOp: max of the operand degrees
//   - MulFeltOp / DivFeltOp: sum of the operand degrees
//   - NegFeltOp: degree of its single operand
// A value whose defining op matches none of the handled cases hits llvm_unreachable.
71 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
// Reuse a previously memoized degree when one exists.
72 if (
auto it = memo.find(val); it != memo.end()) {
76 if (val.isa<BlockArgument>()) {
// Addition/subtraction do not raise the degree: take the larger operand degree.
89 if (
auto addOp = val.getDefiningOp<
AddFeltOp>()) {
90 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
92 if (
auto subOp = val.getDefiningOp<
SubFeltOp>()) {
93 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
// Multiplication adds operand degrees.
95 if (
auto mulOp = val.getDefiningOp<MulFeltOp>()) {
96 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
// NOTE(review): division is given the same rule as multiplication (sum of
// operand degrees); confirm this bound is the intended one for DivFeltOp.
98 if (
auto divOp = val.getDefiningOp<DivFeltOp>()) {
99 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
101 if (
auto negOp = val.getDefiningOp<NegFeltOp>()) {
102 return memo[val] = getDegree(negOp.getOperand(), memo);
// Any other defining op is treated as a programmer error.
105 llvm_unreachable(
"Unhandled Felt SSA value in degree computation");
// Rewrites the felt expression `val` so that its multiplications stay within
// `maxDegree`. Over-degree operands of a MulFeltOp are factored out into new
// auxiliary struct fields (each read back with degree 1). Each aux field gets an
// EmitEqualityOp tying it to the sub-term it replaces. Each aux field is also
// recorded in `auxAssignments`, so the compute function can later write it.
// Rewritten values are cached in `rewrites`, and degrees are cached in `degreeMemo`.
// NOTE(review): `°reeMemo` on the parameter line below is an HTML-entity
// decoding artifact of `&degreeMemo` (the name used throughout the body).
108 Value lowerExpression(
109 Value val, StructDefOp structDef, FuncDefOp constrainFunc,
110 DenseMap<Value, unsigned> °reeMemo, DenseMap<Value, Value> &rewrites,
111 SmallVector<AuxAssignment> &auxAssignments
// Reuse an earlier rewrite of this value, if one exists.
113 if (rewrites.count(val)) {
114 return rewrites[val];
117 unsigned degree = getDegree(val, degreeMemo);
118 if (degree <= maxDegree) {
123 if (
auto mulOp = val.getDefiningOp<MulFeltOp>()) {
// Lower both operands first (bottom-up), then re-measure their degrees.
125 Value lhs = lowerExpression(
126 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
128 Value rhs = lowerExpression(
129 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
132 unsigned lhsDeg = getDegree(lhs, degreeMemo);
133 unsigned rhsDeg = getDegree(rhs, degreeMemo);
// New ops are inserted immediately after the original multiplication.
135 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
// Argument 0 of the constrain function is the struct value that aux fields are read from.
136 Value selfVal = constrainFunc.getArgument(0);
137 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
// Squaring (x * x) over the bound: one aux field equal to x serves both operands.
139 if (lhs == rhs && eraseMul) {
141 FieldDefOp auxField =
addAuxField(structDef, auxName);
143 auto auxVal = builder.create<FieldReadOp>(
144 lhs.getLoc(), lhs.getType(), selfVal, auxField.getNameAttr()
146 auxAssignments.push_back({auxName, lhs});
147 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
148 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, lhs);
151 degreeMemo[auxVal] = 1;
152 rewrites[lhs] = auxVal;
153 rewrites[rhs] = auxVal;
// While the product is over the bound, factor the higher-degree operand (lhs on
// ties) into a fresh aux field whose read has degree 1.
164 while (lhsDeg + rhsDeg > maxDegree) {
165 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
169 FieldDefOp auxField =
addAuxField(structDef, auxName);
172 auto auxVal = builder.create<FieldReadOp>(
173 toFactor.getLoc(), toFactor.getType(), selfVal, auxField.getNameAttr()
177 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
178 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, toFactor);
179 auxAssignments.push_back({auxName, toFactor});
181 rewrites[toFactor] = auxVal;
182 degreeMemo[auxVal] = 1;
190 lhsDeg = getDegree(lhs, degreeMemo);
191 rhsDeg = getDegree(rhs, degreeMemo);
// Rebuild the multiplication from the lowered operands and redirect every use
// of the original op to the new result.
195 auto mulVal = builder.
create<MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
197 mulOp->replaceAllUsesWith(mulVal);
202 degreeMemo[mulVal] = lhsDeg + rhsDeg;
203 rewrites[val] = mulVal;
// Pass entry point. For every StructDefOp in the module, this:
//   1. lowers each side of an EmitEqualityOp (and each CallOp argument) in the
//      constrain function whose degree exceeds `maxDegree`, and
//   2. inserts FieldWriteOps before the compute function's terminator, one for
//      each auxiliary field introduced during lowering.
213 void runOnOperation()
override {
214 ModuleOp moduleOp = getOperation();
// Reject a max-degree option below 2.
218 auto diag = moduleOp.emitError();
219 diag <<
"Invalid max degree: " << maxDegree.getValue() <<
". Must be >= 2.";
225 moduleOp.walk([
this, &moduleOp](StructDefOp structDef) {
// Both the constrain and the compute function are required; a missing one
// is reported as an op error on the struct.
228 if (!constrainFunc) {
229 auto diag = structDef.emitOpError();
238 auto diag = structDef.emitOpError();
239 diag <<
'"' << structDef.getName() <<
"\" doesn't have a \"@" <<
FUNC_NAME_COMPUTE
// Per-struct state shared by all lowerExpression calls below.
251 DenseMap<Value, unsigned> degreeMemo;
252 DenseMap<Value, Value> rewrites;
253 SmallVector<AuxAssignment> auxAssignments;
// Lower each over-degree side of every equality constraint, updating the operand in place.
256 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
259 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
260 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
262 if (degreeLhs > maxDegree) {
263 Value loweredExpr = lowerExpression(
264 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
266 lhsOperand.set(loweredExpr);
268 if (degreeRhs > maxDegree) {
269 Value loweredExpr = lowerExpression(
270 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
272 rhsOperand.set(loweredExpr);
// Containment constraints are not supported yet; report them as a module-level error.
279 constrainFunc.walk([
this, &moduleOp](EmitContainmentOp containOp) {
280 auto diag = moduleOp.emitError();
281 diag <<
"EmitContainmentOp is unsupported for now in the lowering pass";
// Lower over-degree call arguments, then create a new CallOp with the same
// callee and result types, carrying the collected map operand groups.
288 constrainFunc.walk([&](CallOp callOp) {
290 SmallVector<Value> newOperands = llvm::to_vector(callOp.
getArgOperands());
291 bool modified =
false;
293 for (Value &arg : newOperands) {
294 unsigned deg = getDegree(arg, degreeMemo);
297 Value loweredArg = lowerExpression(
298 arg, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
306 SmallVector<ValueRange> mapOperands;
307 OpBuilder builder(callOp);
309 mapOperands.push_back(group);
312 builder.create<CallOp>(
313 callOp.getLoc(), callOp.getResultTypes(), callOp.
getCallee(), mapOperands,
// In the compute function, insert before the terminator a FieldWriteOp
// for each auxiliary assignment recorded during lowering.
321 DenseMap<Value, Value> rebuildMemo;
322 Block &computeBlock = computeFunc.
getBody().front();
323 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
326 for (
const auto &assign : auxAssignments) {
329 builder.create<FieldWriteOp>(
330 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName),
// Creates a default-constructed PolyLoweringPass.
340 return std::make_unique<PolyLoweringPass>();
// Creates a PolyLoweringPass and sets its maximum degree to `maxDegree`.
// NOTE(review): `pass` is already a std::unique_ptr<PolyLoweringPass>, so the
// static_cast below is a no-op; `pass->setMaxDegree(maxDegree)` would be equivalent.
344 auto pass = std::make_unique<PolyLoweringPass>();
345 static_cast<PolyLoweringPass *
>(pass.get())->setMaxDegree(maxDegree);
#define AUXILIARY_FIELD_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
::mlir::OpOperand & getRhsMutable()
::mlir::OpOperand & getLhsMutable()
::mlir::OperandRangeRange getMapOperands()
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::mlir::Operation::operand_range getArgOperands()
::mlir::Region & getBody()
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
Value getSelfValueFromCompute(FuncDefOp computeFunc)
constexpr char FUNC_NAME_CONSTRAIN[]
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
std::unique_ptr< mlir::Pass > createPolyLoweringPass()