22#include "r1cs/Dialect/IR/Attrs.h"
23#include "r1cs/Dialect/IR/Ops.h"
24#include "r1cs/Dialect/IR/Types.h"
26#include <mlir/IR/BuiltinOps.h>
28#include <llvm/ADT/DenseMap.h>
29#include <llvm/ADT/DenseMapInfo.h>
30#include <llvm/ADT/SmallVector.h>
31#include <llvm/Support/Debug.h>
38#define GEN_PASS_DEF_R1CSLOWERINGPASS
49#define DEBUG_TYPE "llzk-r1cs-lowering"
50#define R1CS_AUXILIARY_FIELD_PREFIX "__llzk_r1cs_lowering_pass_aux_field_"
55struct LinearCombination {
56 DenseMap<Value, DynamicAPInt> terms;
57 DynamicAPInt constant;
59 LinearCombination() : constant() {}
61 void addTerm(Value v,
const DynamicAPInt &coeff) {
66 if (!terms.contains(v)) {
73 void addTerm(Value v, int64_t coeff) {
74 DynamicAPInt dynamicCoeff(coeff);
75 return addTerm(v, dynamicCoeff);
79 for (
auto &kv : terms) {
80 kv.second = -kv.second;
85 LinearCombination scaled(
const DynamicAPInt &factor)
const {
86 LinearCombination result;
91 for (
const auto &kv : terms) {
92 result.terms[kv.first] = kv.second * factor;
94 result.constant = constant * factor;
98 LinearCombination scaled(int64_t factor)
const {
99 DynamicAPInt dynamicFactor(factor);
100 return scaled(dynamicFactor);
103 LinearCombination
add(
const LinearCombination &other)
const {
104 LinearCombination result(*
this);
106 for (
const auto &kv : other.terms) {
107 if (!result.terms.contains(kv.first)) {
108 result.terms[kv.first] = kv.second;
110 result.terms[kv.first] = result.terms[kv.first] + kv.second;
113 result.constant = result.constant + other.constant;
117 LinearCombination negated()
const {
return scaled(-1); }
119 void print(raw_ostream &os)
const {
121 for (
const auto &[val, coeff] : terms) {
126 os << coeff <<
'*' << val;
134 if (first && constant == 0) {
141struct R1CSConstraint {
146 R1CSConstraint negated()
const {
147 R1CSConstraint result(*
this);
148 result.a = a.negated();
149 result.c = c.negated();
153 R1CSConstraint scaled(
const DynamicAPInt &factor)
const {
154 R1CSConstraint result(*
this);
155 result.a = a.scaled(factor);
156 result.c = c.scaled(factor);
160 R1CSConstraint(
const DynamicAPInt &constant) { c.constant = constant; }
162 R1CSConstraint() =
default;
164 inline bool isLinearOnly()
const {
return a.terms.empty() && b.terms.empty(); }
166 R1CSConstraint multiply(
const R1CSConstraint &other) {
167 auto isDegZero = [](
const R1CSConstraint &constraint) {
168 return constraint.a.terms.empty() && constraint.b.terms.empty() && constraint.c.terms.empty();
171 if (isDegZero(other)) {
172 return this->scaled(other.c.constant);
174 if (isDegZero(*
this)) {
175 return other.scaled(this->c.constant);
178 if (isLinearOnly() && other.isLinearOnly()) {
179 R1CSConstraint result;
188 llvm::errs() <<
"R1CSConstraint::multiply: Only supported for purely linear constraints.\n";
189 llvm_unreachable(
"Invalid multiply: non-linear constraint(s) involved");
192 R1CSConstraint
add(
const R1CSConstraint &other) {
194 if (isLinearOnly()) {
195 R1CSConstraint result(other);
196 result.c = result.c.add(this->c);
199 if (other.isLinearOnly()) {
200 R1CSConstraint result(*
this);
201 result.c = result.c.add(other.c);
204 llvm::errs() <<
"R1CSConstraint::add: Only supported for purely linear constraints.\n";
205 llvm_unreachable(
"Invalid add: non-linear constraint(s) involved");
208 void print(raw_ostream &os)
const {
220 unsigned auxCounter = 0;
230 void getPostOrder(Value root, SmallVectorImpl<Value> &postOrder) {
231 SmallVector<Value, 16> worklist;
232 DenseSet<Value> visited;
234 worklist.push_back(root);
236 while (!worklist.empty()) {
237 Value val = worklist.back();
239 if (visited.contains(val)) {
241 postOrder.push_back(val);
246 if (Operation *op = val.getDefiningOp()) {
247 for (Value operand : op->getOperands()) {
248 worklist.push_back(operand);
278 Value normalizeForR1CS(
280 DenseMap<Value, unsigned> °reeMemo, DenseMap<Value, Value> &rewrites,
281 SmallVectorImpl<AuxAssignment> &auxAssignments, OpBuilder &builder
283 if (
auto it = rewrites.find(root); it != rewrites.end()) {
287 SmallVector<Value, 16> postOrder;
288 getPostOrder(root, postOrder);
292 for (Value val : postOrder) {
293 if (rewrites.contains(val)) {
297 Operation *op = val.getDefiningOp();
307 if (
auto c = llvm::dyn_cast<FeltConstantOp>(op)) {
314 if (
auto fr = llvm::dyn_cast<FieldReadOp>(op)) {
321 auto getDeg = [°reeMemo](Value v) ->
unsigned {
322 auto it = degreeMemo.find(v);
323 assert(it != degreeMemo.end() &&
"Missing degree");
331 auto handleAddOrSub = [&](Value lhsOrig, Value rhsOrig,
bool isAdd) {
332 Value lhs = rewrites[lhsOrig];
333 Value rhs = rewrites[rhsOrig];
334 unsigned degLhs = getDeg(lhs);
335 unsigned degRhs = getDeg(rhs);
337 if (degLhs == 2 && degRhs == 2) {
338 builder.setInsertionPoint(op);
340 FieldDefOp auxField =
addAuxField(structDef, auxName);
341 Value aux = builder.create<FieldReadOp>(
343 auxField.getNameAttr()
345 auto eqOp = builder.create<EmitEqualityOp>(val.getLoc(), aux, lhs);
346 auxAssignments.push_back({auxName, lhs});
353 Operation *newOp = isAdd
354 ? builder.create<AddFeltOp>(val.getLoc(), val.getType(), lhs, rhs)
355 : builder.
create<SubFeltOp>(val.getLoc(), val.getType(), lhs, rhs);
356 Value result = newOp->getResult(0);
357 degreeMemo[result] = std::max(degLhs, degRhs);
358 rewrites[val] = result;
359 rewrites[result] = result;
360 val.replaceAllUsesWith(result);
361 if (val.use_empty()) {
365 degreeMemo[val] = std::max(degLhs, degRhs);
370 if (
auto add = llvm::dyn_cast<AddFeltOp>(op)) {
371 handleAddOrSub(
add.getLhs(),
add.getRhs(),
true);
375 if (
auto sub = llvm::dyn_cast<SubFeltOp>(op)) {
376 handleAddOrSub(
sub.getLhs(),
sub.getRhs(),
false);
383 if (
auto mul = llvm::dyn_cast<MulFeltOp>(op)) {
384 Value lhs = rewrites[
mul.getLhs()];
385 Value rhs = rewrites[
mul.getRhs()];
386 unsigned degLhs = getDeg(lhs);
387 unsigned degRhs = getDeg(rhs);
389 degreeMemo[val] = degLhs + degRhs;
396 if (
auto neg = llvm::dyn_cast<NegFeltOp>(op)) {
397 Value inner = rewrites[
neg.getOperand()];
398 unsigned deg = getDeg(inner);
399 degreeMemo[val] = deg;
404 llvm::errs() <<
"Unhandled op in normalize ForR1CS: " << *op <<
'\n';
408 return rewrites[root];
411 R1CSConstraint lowerPolyToR1CS(Value poly) {
413 SmallVector<Value, 16> worklist = {poly};
414 DenseMap<Value, R1CSConstraint> constraintMap;
415 DenseSet<Value> visited;
416 SmallVector<Value, 16> postorder;
418 getPostOrder(poly, postorder);
421 for (Value v : postorder) {
422 Operation *op = v.getDefiningOp();
423 if (!op || llvm::isa<FieldReadOp>(op)) {
427 constraintMap[v] = eq;
430 if (
auto add = dyn_cast<AddFeltOp>(op)) {
431 R1CSConstraint lhsC = constraintMap[
add.getLhs()];
432 R1CSConstraint rhsC = constraintMap[
add.getRhs()];
433 constraintMap[v] = lhsC.add(rhsC);
434 }
else if (
auto sub = dyn_cast<SubFeltOp>(op)) {
435 R1CSConstraint lhsC = constraintMap[
sub.getLhs()];
436 R1CSConstraint rhsC = constraintMap[
sub.getRhs()];
437 constraintMap[v] = lhsC.add(rhsC.negated());
438 }
else if (
auto mul = dyn_cast<MulFeltOp>(op)) {
439 R1CSConstraint lhsC = constraintMap[
mul.getLhs()];
440 R1CSConstraint rhsC = constraintMap[
mul.getRhs()];
441 constraintMap[v] = lhsC.multiply(rhsC);
442 }
else if (
auto neg = dyn_cast<NegFeltOp>(op)) {
443 R1CSConstraint inner = constraintMap[op->getOperand(0)];
444 constraintMap[v] = inner.negated();
445 }
else if (
auto cst = dyn_cast<FeltConstantOp>(op)) {
447 constraintMap[v] = c;
449 llvm::errs() <<
"Unhandled op in R1CS lowering: " << *op <<
'\n';
450 llvm_unreachable(
"unhandled op");
454 return constraintMap[poly];
458 lowerEquationToR1CS(Value p, Value q,
const DenseMap<Value, unsigned> °reeMemo) {
459 R1CSConstraint pconst = lowerPolyToR1CS(p);
460 R1CSConstraint qconst = lowerPolyToR1CS(q);
462 if (degreeMemo.at(p) == 2) {
463 return pconst.add(qconst.negated());
465 return qconst.add(pconst.negated());
468 Value emitLinearCombination(
469 LinearCombination lc, IRMapping &valueMap, DenseMap<StringRef, Value> &fieldMap,
470 OpBuilder &builder, Location loc
472 Value result =
nullptr;
474 auto getMapping = [&valueMap, &fieldMap,
this](
const Value &v) {
475 if (!valueMap.contains(v)) {
476 Operation *op = v.getDefiningOp();
477 if (
auto read = dyn_cast<FieldReadOp>(op)) {
478 auto fieldVal = fieldMap.find(read.getFieldName());
479 assert(fieldVal != fieldMap.end() &&
"Field read not associated with a value");
480 return fieldVal->second;
482 op->emitError(
"Value not mapped in R1CS lowering").report();
485 return valueMap.lookup(v);
488 auto linearTy = r1cs::LinearType::get(builder.getContext());
491 if (lc.constant != 0) {
492 result = builder.create<r1cs::ConstOp>(
493 loc, linearTy, r1cs::FeltAttr::get(builder.getContext(),
toAPSInt(lc.constant))
497 for (
auto &[val, coeff] : lc.terms) {
498 Value mapped = getMapping(val);
501 Value lin = builder.create<r1cs::ToLinearOp>(loc, linearTy, mapped);
503 Value scaled = coeff == 1 ? lin
504 : builder.create<r1cs::MulConstOp>(
506 r1cs::FeltAttr::get(builder.getContext(),
toAPSInt(coeff))
513 result = builder.create<r1cs::AddOp>(loc, linearTy, result, scaled);
519 result = builder.create<r1cs::ConstOp>(
520 loc, r1cs::LinearType::get(builder.getContext()),
521 r1cs::FeltAttr::get(builder.getContext(), 0)
528 void buildAndEmitR1CS(
529 ModuleOp &moduleOp, StructDefOp &structDef, FuncDefOp &constrainFunc,
530 DenseMap<Value, unsigned> °reeMemo
532 SmallVector<R1CSConstraint, 16> constraints;
533 constrainFunc.walk([&](EmitEqualityOp eqOp) {
534 OpBuilder builder(eqOp);
535 R1CSConstraint eq = lowerEquationToR1CS(eqOp.
getLhs(), eqOp.
getRhs(), degreeMemo);
536 constraints.push_back(eq);
538 moduleOp->setAttr(
LANG_ATTR_NAME, StringAttr::get(moduleOp.getContext(),
"r1cs"));
539 Block &entryBlock = constrainFunc.
getBody().front();
541 Location loc = structDef.getLoc();
542 OpBuilder topBuilder(moduleOp.getBodyRegion());
545 bool hasPublicSignals =
false;
547 if (!llvm::isa<FeltType>(field.getType())) {
548 field.emitError(
"Only felt fields are supported as output signals").report();
552 if (field.isPublic()) {
553 hasPublicSignals =
true;
557 if (!hasPublicSignals) {
558 structDef.emitError(
"Struct should have at least one public output").report();
560 llvm::SmallVector<mlir::NamedAttribute> argAttrPairs;
562 for (
auto [i, arg] : llvm::enumerate(llvm::drop_begin(entryBlock.getArguments(), 1))) {
564 auto key = topBuilder.getStringAttr(std::to_string(i));
565 auto value = r1cs::PublicAttr::get(moduleOp.getContext());
566 argAttrPairs.emplace_back(key, value);
569 auto dictAttr = topBuilder.getDictionaryAttr(argAttrPairs);
571 topBuilder.create<r1cs::CircuitDefOp>(loc, structDef.
getSymName().str(), dictAttr);
573 Block *circuitBlock = circuit.addEntryBlock();
575 OpBuilder bodyBuilder = OpBuilder::atBlockEnd(circuitBlock);
578 for (
auto [i, arg] : llvm::enumerate(llvm::drop_begin(entryBlock.getArguments(), 1))) {
579 if (!llvm::isa<FeltType>(arg.getType())) {
580 constrainFunc.emitOpError(
"All input arguments must be of felt type").report();
584 auto blockArg = circuitBlock->addArgument(bodyBuilder.getType<r1cs::SignalType>(), loc);
585 valueMap.map(arg, blockArg);
590 DenseMap<StringRef, Value> fieldSignalMap;
591 uint32_t signalDefCntr = 0;
593 r1cs::PublicAttr pubAttr;
594 if (field.hasPublicAttr()) {
595 pubAttr = bodyBuilder.getAttr<r1cs::PublicAttr>();
597 auto defOp = bodyBuilder.create<r1cs::SignalDefOp>(
598 field.getLoc(), bodyBuilder.getType<r1cs::SignalType>(),
599 bodyBuilder.getUI32IntegerAttr(signalDefCntr), pubAttr
602 fieldSignalMap.insert({field.getName(), defOp.getOut()});
604 DenseMap<std::tuple<Value, Value, StringRef>, Value> binaryOpCache;
606 for (
const R1CSConstraint &constraint : constraints) {
607 Value aVal = emitLinearCombination(constraint.a, valueMap, fieldSignalMap, bodyBuilder, loc);
608 Value bVal = emitLinearCombination(constraint.b, valueMap, fieldSignalMap, bodyBuilder, loc);
609 Value cVal = emitLinearCombination(constraint.c, valueMap, fieldSignalMap, bodyBuilder, loc);
610 bodyBuilder.create<r1cs::ConstrainOp>(loc, aVal, bVal, cVal);
614 void getDependentDialects(mlir::DialectRegistry ®istry)
const override {
615 registry.insert<r1cs::R1CSDialect>();
618 void runOnOperation()
override {
619 ModuleOp moduleOp = getOperation();
621 moduleOp->getContext()->getLoadedDialect<r1cs::R1CSDialect>() &&
"R1CS dialect not loaded"
623 moduleOp.walk([
this, &moduleOp](StructDefOp structDef) {
626 if (!constrainFunc || !computeFunc) {
627 structDef.emitOpError(
"Missing compute or constrain function").report();
637 DenseMap<Value, unsigned> degreeMemo;
638 DenseMap<Value, Value> rewrites;
639 SmallVector<AuxAssignment> auxAssignments;
641 constrainFunc.walk([&](EmitEqualityOp eqOp) {
642 OpBuilder builder(eqOp);
643 Value lhs = normalizeForR1CS(
644 eqOp.
getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments, builder
646 Value rhs = normalizeForR1CS(
647 eqOp.
getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments, builder
650 unsigned degLhs = degreeMemo.lookup(lhs);
651 unsigned degRhs = degreeMemo.lookup(rhs);
654 if (degLhs == 2 && degRhs == 2) {
655 builder.setInsertionPoint(eqOp);
657 FieldDefOp auxField =
addAuxField(structDef, auxName);
658 Value aux = builder.create<FieldReadOp>(
660 auxField.getNameAttr()
662 auto eqAux = builder.create<EmitEqualityOp>(eqOp.getLoc(), aux, lhs);
663 auxAssignments.push_back({auxName, lhs});
669 builder.create<EmitEqualityOp>(eqOp.getLoc(), lhs, rhs);
673 Block &computeBlock = computeFunc.
getBody().front();
674 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
676 DenseMap<Value, Value> rebuildMemo;
678 for (
const auto &assign : auxAssignments) {
680 builder.create<FieldWriteOp>(
681 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName), expr
684 buildAndEmitR1CS(moduleOp, structDef, constrainFunc, degreeMemo);
692 return std::make_unique<R1CSLoweringPass>();
This file implements helper methods for constructing DynamicAPInts.
#define R1CS_AUXILIARY_FIELD_PREFIX
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
::mlir::TypedValue<::mlir::Type > getLhs()
::mlir::TypedValue<::mlir::Type > getRhs()
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
::mlir::Region & getBody()
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the IR language name.
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
DynamicAPInt toDynamicAPInt(StringRef str)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
APSInt toAPSInt(const DynamicAPInt &i)
std::unique_ptr< mlir::Pass > createR1CSLoweringPass()
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.