22#include <mlir/IR/BuiltinOps.h>
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
34#define GEN_PASS_DECL_POLYLOWERINGPASS
35#define GEN_PASS_DEF_POLYLOWERINGPASS
46#define DEBUG_TYPE "llzk-poly-lowering-pass"
47#define AUXILIARY_FIELD_PREFIX "__llzk_poly_lowering_pass_aux_field_"
52 std::string auxFieldName;
58 void setMaxDegree(
unsigned degree) { this->maxDegree = degree; }
61 unsigned auxCounter = 0;
63 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
65 structDefs.push_back(structDef);
66 return WalkResult::skip();
71 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
72 if (
auto it = memo.find(val); it != memo.end()) {
76 if (val.isa<BlockArgument>()) {
89 if (
auto addOp = val.getDefiningOp<
AddFeltOp>()) {
90 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
92 if (
auto subOp = val.getDefiningOp<
SubFeltOp>()) {
93 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
95 if (
auto mulOp = val.getDefiningOp<
MulFeltOp>()) {
96 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
99 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
101 if (
auto negOp = val.getDefiningOp<
NegFeltOp>()) {
102 return memo[val] = getDegree(negOp.getOperand(), memo);
105 llvm_unreachable(
"Unhandled Felt SSA value in degree computation");
108 Value lowerExpression(
110 DenseMap<Value, unsigned> °reeMemo, DenseMap<Value, Value> &rewrites,
111 SmallVector<AuxAssignment> &auxAssignments
113 if (rewrites.count(val)) {
114 return rewrites[val];
117 unsigned degree = getDegree(val, degreeMemo);
123 if (
auto mulOp = val.getDefiningOp<
MulFeltOp>()) {
125 Value lhs = lowerExpression(
126 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
128 Value rhs = lowerExpression(
129 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
132 unsigned lhsDeg = getDegree(lhs, degreeMemo);
133 unsigned rhsDeg = getDegree(rhs, degreeMemo);
135 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
137 bool eraseMul = lhsDeg + rhsDeg >
maxDegree;
139 if (lhs == rhs && eraseMul) {
144 lhs.getLoc(), lhs.getType(), selfVal, auxField.getNameAttr()
146 auxAssignments.push_back({auxName, lhs});
147 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
151 degreeMemo[auxVal] = 1;
152 rewrites[lhs] = auxVal;
153 rewrites[rhs] = auxVal;
165 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
173 toFactor.getLoc(), toFactor.getType(), selfVal, auxField.getNameAttr()
177 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
178 auto eqOp = builder.create<
EmitEqualityOp>(loc, auxVal, toFactor);
179 auxAssignments.push_back({auxName, toFactor});
181 rewrites[toFactor] = auxVal;
182 degreeMemo[auxVal] = 1;
190 lhsDeg = getDegree(lhs, degreeMemo);
191 rhsDeg = getDegree(rhs, degreeMemo);
195 auto mulVal = builder.create<
MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
197 mulOp->replaceAllUsesWith(mulVal);
202 degreeMemo[mulVal] = lhsDeg + rhsDeg;
203 rewrites[val] = mulVal;
213 void runOnOperation()
override {
214 ModuleOp moduleOp = getOperation();
218 auto diag = moduleOp.emitError();
219 diag <<
"Invalid max degree: " << maxDegree.getValue() <<
". Must be >= 2.";
225 moduleOp.walk([
this, &moduleOp](StructDefOp structDef) {
228 if (!constrainFunc) {
229 auto diag = structDef.emitOpError();
238 auto diag = structDef.emitOpError();
239 diag <<
'"' << structDef.getName() <<
"\" doesn't have a \"@" <<
FUNC_NAME_COMPUTE
251 DenseMap<Value, unsigned> degreeMemo;
252 DenseMap<Value, Value> rewrites;
253 SmallVector<AuxAssignment> auxAssignments;
256 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
259 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
260 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
262 if (degreeLhs > maxDegree) {
263 Value loweredExpr = lowerExpression(
264 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
266 lhsOperand.set(loweredExpr);
268 if (degreeRhs > maxDegree) {
269 Value loweredExpr = lowerExpression(
270 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
272 rhsOperand.set(loweredExpr);
279 constrainFunc.walk([
this, &moduleOp](EmitContainmentOp containOp) {
280 auto diag = moduleOp.emitError();
281 diag <<
"EmitContainmentOp is unsupported for now in the lowering pass";
288 constrainFunc.walk([&](CallOp callOp) {
290 SmallVector<Value> newOperands = llvm::to_vector(callOp.
getArgOperands());
291 bool modified =
false;
293 for (Value &arg : newOperands) {
294 unsigned deg = getDegree(arg, degreeMemo);
297 Value loweredArg = lowerExpression(
298 arg, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
306 OpBuilder builder(callOp);
307 builder.create<CallOp>(
308 callOp.getLoc(), callOp.getResultTypes(), callOp.
getCallee(),
317 DenseMap<Value, Value> rebuildMemo;
318 Block &computeBlock = computeFunc.
getBody().front();
319 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
322 for (
const auto &assign : auxAssignments) {
325 builder.create<FieldWriteOp>(
326 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName),
336 return std::make_unique<PolyLoweringPass>();
340 auto pass = std::make_unique<PolyLoweringPass>();
341 static_cast<PolyLoweringPass *
>(pass.get())->setMaxDegree(maxDegree);
#define AUXILIARY_FIELD_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
::mlir::OpOperand & getRhsMutable()
::mlir::OpOperand & getLhsMutable()
::mlir::OperandRangeRange getMapOperands()
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Operation::operand_range getArgOperands()
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::Region & getBody()
::mlir::Pass::Option< unsigned > maxDegree
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
constexpr char FUNC_NAME_CONSTRAIN[]
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
std::unique_ptr< mlir::Pass > createPolyLoweringPass()