22#include "r1cs/Dialect/IR/Attrs.h"
23#include "r1cs/Dialect/IR/Ops.h"
24#include "r1cs/Dialect/IR/Types.h"
26#include <mlir/IR/BuiltinOps.h>
28#include <llvm/ADT/DenseMap.h>
29#include <llvm/ADT/DenseMapInfo.h>
30#include <llvm/ADT/SmallVector.h>
31#include <llvm/Support/Debug.h>
38#define GEN_PASS_DEF_R1CSLOWERINGPASS
49#define DEBUG_TYPE "llzk-r1cs-lowering"
50#define R1CS_AUXILIARY_FIELD_PREFIX "__llzk_r1cs_lowering_pass_aux_field_"
55struct LinearCombination {
56 DenseMap<Value, APSInt> terms;
59 LinearCombination() : constant(APInt(1, 0),
false) {}
61 void addTerm(Value v, APSInt coeff) {
66 if (!terms.contains(v)) {
74 for (
auto &kv : terms) {
75 kv.second = -kv.second;
/// Adds `c` to this combination's constant term, in place.
/// NOTE(review): this uses plain `+=`, whereas add() combines constants via
/// expandingAdd; presumably callers guarantee compatible widths -- confirm.
80 void addConstant(APSInt c) { constant += c; }
82 LinearCombination scaled(
const APSInt &factor)
const {
83 LinearCombination result;
84 if (factor.isZero()) {
88 for (
const auto &kv : terms) {
89 result.terms[kv.first] =
expandingMul(kv.second, factor);
95 LinearCombination
add(
const LinearCombination &other)
const {
96 LinearCombination result(*
this);
98 for (
const auto &kv : other.terms) {
99 if (!result.terms.contains(kv.first)) {
100 result.terms[kv.first] = kv.second;
102 result.terms[kv.first] =
expandingAdd(result.terms[kv.first], kv.second);
105 result.constant =
expandingAdd(result.constant, other.constant);
109 LinearCombination negated()
const {
110 return scaled(APSInt(APInt(constant.getBitWidth(), -1,
true),
false));
113 void print(raw_ostream &os)
const {
115 for (
const auto &[val, coeff] : terms) {
120 os << coeff <<
'*' << val;
122 if (!constant.isZero()) {
128 if (first && constant.isZero()) {
135struct R1CSConstraint {
140 R1CSConstraint negated()
const {
141 R1CSConstraint result(*
this);
142 result.a = a.negated();
143 result.c = c.negated();
147 R1CSConstraint scaled(
const APSInt &factor)
const {
148 R1CSConstraint result(*
this);
149 result.a = a.scaled(factor);
150 result.c = c.scaled(factor);
/// Builds a constraint whose `c` side carries `constant` as its constant term.
/// No other member is assigned by this constructor.
154 R1CSConstraint(
const APSInt &constant) { c.constant = constant; }
/// Compiler-generated default constructor; every member takes its own default.
156 R1CSConstraint() =
default;
/// Returns true when neither `a` nor `b` holds any variable term.
/// Only the `terms` maps are inspected; the constants of `a`/`b` and the whole
/// of `c` are not consulted.
158 inline bool isLinearOnly()
const {
return a.terms.empty() && b.terms.empty(); }
160 R1CSConstraint multiply(
const R1CSConstraint &other) {
161 auto isDegZero = [](
const R1CSConstraint &constraint) {
162 return constraint.a.terms.empty() && constraint.b.terms.empty() && constraint.c.terms.empty();
165 if (isDegZero(other)) {
166 return this->scaled(other.c.constant);
168 if (isDegZero(*
this)) {
169 return other.scaled(this->c.constant);
172 if (isLinearOnly() && other.isLinearOnly()) {
173 R1CSConstraint result;
182 llvm::errs() <<
"R1CSConstraint::multiply: Only supported for purely linear constraints.\n";
183 llvm_unreachable(
"Invalid multiply: non-linear constraint(s) involved");
186 R1CSConstraint
add(
const R1CSConstraint &other) {
188 if (isLinearOnly()) {
189 R1CSConstraint result(other);
190 result.c = result.c.add(this->c);
193 if (other.isLinearOnly()) {
194 R1CSConstraint result(*
this);
195 result.c = result.c.add(other.c);
198 llvm::errs() <<
"R1CSConstraint::add: Only supported for purely linear constraints.\n";
199 llvm_unreachable(
"Invalid add: non-linear constraint(s) involved");
202 void print(raw_ostream &os)
const {
214 unsigned auxCounter = 0;
224 void getPostOrder(Value root, SmallVectorImpl<Value> &postOrder) {
225 SmallVector<Value, 16> worklist;
226 DenseSet<Value> visited;
228 worklist.push_back(root);
230 while (!worklist.empty()) {
231 Value val = worklist.back();
233 if (visited.contains(val)) {
235 postOrder.push_back(val);
240 if (Operation *op = val.getDefiningOp()) {
241 for (Value operand : op->getOperands()) {
242 worklist.push_back(operand);
272 Value normalizeForR1CS(
273 Value root, StructDefOp structDef, FuncDefOp constrainFunc,
274 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
275 SmallVectorImpl<AuxAssignment> &auxAssignments, OpBuilder &builder
277 if (
auto it = rewrites.find(root); it != rewrites.end()) {
281 SmallVector<Value, 16> postOrder;
282 getPostOrder(root, postOrder);
286 for (Value val : postOrder) {
287 if (rewrites.contains(val)) {
291 Operation *op = val.getDefiningOp();
301 if (
auto c = llvm::dyn_cast<FeltConstantOp>(op)) {
308 if (
auto fr = llvm::dyn_cast<FieldReadOp>(op)) {
315 auto getDeg = [&degreeMemo](Value v) ->
unsigned {
316 auto it = degreeMemo.find(v);
317 assert(it != degreeMemo.end() &&
"Missing degree");
325 auto handleAddOrSub = [&](Value lhsOrig, Value rhsOrig,
bool isAdd) {
326 Value lhs = rewrites[lhsOrig];
327 Value rhs = rewrites[rhsOrig];
328 unsigned degLhs = getDeg(lhs);
329 unsigned degRhs = getDeg(rhs);
331 if (degLhs == 2 && degRhs == 2) {
332 builder.setInsertionPoint(op);
334 FieldDefOp auxField =
addAuxField(structDef, auxName);
335 Value self = constrainFunc.getArgument(0);
336 Value aux = builder.
create<FieldReadOp>(
337 val.getLoc(), val.getType(), self, auxField.getNameAttr()
339 auto eqOp = builder.create<EmitEqualityOp>(val.getLoc(), aux, lhs);
340 auxAssignments.push_back({auxName, lhs});
347 Operation *newOp = isAdd
348 ? builder.create<AddFeltOp>(val.getLoc(), val.getType(), lhs, rhs)
349 : builder.
create<SubFeltOp>(val.getLoc(), val.getType(), lhs, rhs);
350 Value result = newOp->getResult(0);
351 degreeMemo[result] = std::max(degLhs, degRhs);
352 rewrites[val] = result;
353 rewrites[result] = result;
354 val.replaceAllUsesWith(result);
355 if (val.use_empty()) {
359 degreeMemo[val] = std::max(degLhs, degRhs);
364 if (
auto add = llvm::dyn_cast<AddFeltOp>(op)) {
365 handleAddOrSub(
add.getLhs(),
add.getRhs(),
true);
369 if (
auto sub = llvm::dyn_cast<SubFeltOp>(op)) {
370 handleAddOrSub(
sub.getLhs(),
sub.getRhs(),
false);
377 if (
auto mul = llvm::dyn_cast<MulFeltOp>(op)) {
378 Value lhs = rewrites[
mul.getLhs()];
379 Value rhs = rewrites[
mul.getRhs()];
380 unsigned degLhs = getDeg(lhs);
381 unsigned degRhs = getDeg(rhs);
383 degreeMemo[val] = degLhs + degRhs;
390 if (
auto neg = llvm::dyn_cast<NegFeltOp>(op)) {
391 Value inner = rewrites[
neg.getOperand()];
392 unsigned deg = getDeg(inner);
393 degreeMemo[val] = deg;
398 llvm::errs() <<
"Unhandled op in normalize ForR1CS: " << *op <<
'\n';
402 return rewrites[root];
405 R1CSConstraint lowerPolyToR1CS(Value poly) {
407 SmallVector<Value, 16> worklist = {poly};
408 DenseMap<Value, R1CSConstraint> constraintMap;
409 DenseSet<Value> visited;
410 SmallVector<Value, 16> postorder;
412 getPostOrder(poly, postorder);
415 for (Value v : postorder) {
416 Operation *op = v.getDefiningOp();
417 if (!op || llvm::isa<FieldReadOp>(op)) {
420 eq.c.addTerm(v, APSInt::get(1));
421 constraintMap[v] = eq;
424 if (
auto add = dyn_cast<AddFeltOp>(op)) {
425 R1CSConstraint lhsC = constraintMap[
add.getLhs()];
426 R1CSConstraint rhsC = constraintMap[
add.getRhs()];
427 constraintMap[v] = lhsC.add(rhsC);
428 }
else if (
auto sub = dyn_cast<SubFeltOp>(op)) {
429 R1CSConstraint lhsC = constraintMap[
sub.getLhs()];
430 R1CSConstraint rhsC = constraintMap[
sub.getRhs()];
431 constraintMap[v] = lhsC.add(rhsC.negated());
432 }
else if (
auto mul = dyn_cast<MulFeltOp>(op)) {
433 R1CSConstraint lhsC = constraintMap[
mul.getLhs()];
434 R1CSConstraint rhsC = constraintMap[
mul.getRhs()];
435 constraintMap[v] = lhsC.multiply(rhsC);
436 }
else if (
auto neg = dyn_cast<NegFeltOp>(op)) {
437 R1CSConstraint inner = constraintMap[op->getOperand(0)];
438 constraintMap[v] = inner.negated();
439 }
else if (
auto cst = dyn_cast<FeltConstantOp>(op)) {
440 R1CSConstraint c(APSInt(cst.getValueAttr().getValue(),
false));
441 constraintMap[v] = c;
443 llvm::errs() <<
"Unhandled op in R1CS lowering: " << *op <<
'\n';
444 llvm_unreachable(
"unhandled op");
448 return constraintMap[poly];
452 lowerEquationToR1CS(Value p, Value q,
const DenseMap<Value, unsigned> &degreeMemo) {
453 R1CSConstraint pconst = lowerPolyToR1CS(p);
454 R1CSConstraint qconst = lowerPolyToR1CS(q);
456 if (degreeMemo.at(p) == 2) {
457 return pconst.add(qconst.negated());
459 return qconst.add(pconst.negated());
462 Value emitLinearCombination(
463 LinearCombination lc, IRMapping &valueMap, DenseMap<StringRef, Value> &fieldMap,
464 OpBuilder &builder, Location loc
466 Value result =
nullptr;
468 auto getMapping = [&valueMap, &fieldMap,
this](
const Value &v) {
469 if (!valueMap.contains(v)) {
470 Operation *op = v.getDefiningOp();
471 if (
auto read = dyn_cast<FieldReadOp>(op)) {
472 auto fieldVal = fieldMap.find(read.getFieldName());
473 assert(fieldVal != fieldMap.end() &&
"Field read not associated with a value");
474 return fieldVal->second;
476 op->emitError(
"Value not mapped in R1CS lowering").report();
479 return valueMap.lookup(v);
482 auto linearTy = r1cs::LinearType::get(builder.getContext());
485 if (!lc.constant.isZero()) {
486 result = builder.create<r1cs::ConstOp>(
487 loc, linearTy, r1cs::FeltAttr::get(builder.getContext(), lc.constant)
491 for (
auto &[val, coeff] : lc.terms) {
492 Value mapped = getMapping(val);
495 Value lin = builder.create<r1cs::ToLinearOp>(loc, linearTy, mapped);
497 Value scaled = coeff == 1
499 : builder.create<r1cs::MulConstOp>(
500 loc, linearTy, lin, r1cs::FeltAttr::get(builder.getContext(), coeff)
507 result = builder.create<r1cs::AddOp>(loc, linearTy, result, scaled);
513 result = builder.create<r1cs::ConstOp>(
514 loc, r1cs::LinearType::get(builder.getContext()),
515 r1cs::FeltAttr::get(builder.getContext(), 0)
522 void buildAndEmitR1CS(
523 ModuleOp &moduleOp, StructDefOp &structDef, FuncDefOp &constrainFunc,
524 DenseMap<Value, unsigned> &degreeMemo
526 SmallVector<R1CSConstraint, 16> constraints;
527 constrainFunc.walk([&](EmitEqualityOp eqOp) {
528 OpBuilder builder(eqOp);
529 R1CSConstraint eq = lowerEquationToR1CS(eqOp.
getLhs(), eqOp.
getRhs(), degreeMemo);
530 constraints.push_back(eq);
532 moduleOp->setAttr(
LANG_ATTR_NAME, StringAttr::get(moduleOp.getContext(),
"r1cs"));
533 Block &entryBlock = constrainFunc.
getBody().front();
535 Location loc = structDef.getLoc();
536 OpBuilder topBuilder(moduleOp.getBodyRegion());
539 bool hasPublicSignals =
false;
541 if (!field.getType().isa<FeltType>()) {
542 field.emitError(
"Only felt fields are supported as output signals").report();
546 if (field.isPublic()) {
547 hasPublicSignals =
true;
551 if (!hasPublicSignals) {
552 structDef.emitError(
"Struct should have at least one public output").report();
554 llvm::SmallVector<mlir::NamedAttribute> argAttrPairs;
556 for (
auto [i, arg] : llvm::enumerate(llvm::drop_begin(entryBlock.getArguments(), 1))) {
558 auto key = topBuilder.getStringAttr(std::to_string(i));
559 auto value = r1cs::PublicAttr::get(moduleOp.getContext());
560 argAttrPairs.emplace_back(key, value);
563 auto dictAttr = topBuilder.getDictionaryAttr(argAttrPairs);
565 topBuilder.create<r1cs::CircuitDefOp>(loc, structDef.
getSymName().str(), dictAttr);
567 Block *circuitBlock = circuit.addEntryBlock();
569 OpBuilder bodyBuilder = OpBuilder::atBlockEnd(circuitBlock);
572 for (
auto [i, arg] : llvm::enumerate(llvm::drop_begin(entryBlock.getArguments(), 1))) {
573 if (!arg.getType().isa<FeltType>()) {
574 constrainFunc.emitOpError(
"All input arguments must be of felt type").report();
578 auto blockArg = circuitBlock->addArgument(bodyBuilder.getType<r1cs::SignalType>(), loc);
579 valueMap.map(arg, blockArg);
584 DenseMap<StringRef, Value> fieldSignalMap;
585 uint32_t signalDefCntr = 0;
587 r1cs::PublicAttr pubAttr;
588 if (field.hasPublicAttr()) {
589 pubAttr = bodyBuilder.getAttr<r1cs::PublicAttr>();
591 auto defOp = bodyBuilder.create<r1cs::SignalDefOp>(
592 field.getLoc(), bodyBuilder.getType<r1cs::SignalType>(),
593 bodyBuilder.getUI32IntegerAttr(signalDefCntr), pubAttr
596 fieldSignalMap.insert({field.getName(), defOp.getOut()});
598 DenseMap<std::tuple<Value, Value, StringRef>, Value> binaryOpCache;
600 for (
const R1CSConstraint &constraint : constraints) {
601 Value aVal = emitLinearCombination(constraint.a, valueMap, fieldSignalMap, bodyBuilder, loc);
602 Value bVal = emitLinearCombination(constraint.b, valueMap, fieldSignalMap, bodyBuilder, loc);
603 Value cVal = emitLinearCombination(constraint.c, valueMap, fieldSignalMap, bodyBuilder, loc);
604 bodyBuilder.create<r1cs::ConstrainOp>(loc, aVal, bVal, cVal);
608 void getDependentDialects(mlir::DialectRegistry &registry)
const override {
609 registry.insert<r1cs::R1CSDialect>();
612 void runOnOperation()
override {
613 ModuleOp moduleOp = getOperation();
615 moduleOp->getContext()->getLoadedDialect<r1cs::R1CSDialect>() &&
"R1CS dialect not loaded"
617 moduleOp.walk([
this, &moduleOp](StructDefOp structDef) {
620 if (!constrainFunc || !computeFunc) {
621 structDef.emitOpError(
"Missing compute or constrain function").report();
631 DenseMap<Value, unsigned> degreeMemo;
632 DenseMap<Value, Value> rewrites;
633 SmallVector<AuxAssignment> auxAssignments;
635 constrainFunc.walk([&](EmitEqualityOp eqOp) {
636 OpBuilder builder(eqOp);
637 Value lhs = normalizeForR1CS(
638 eqOp.
getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments, builder
640 Value rhs = normalizeForR1CS(
641 eqOp.
getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments, builder
644 unsigned degLhs = degreeMemo.lookup(lhs);
645 unsigned degRhs = degreeMemo.lookup(rhs);
648 if (degLhs == 2 && degRhs == 2) {
649 builder.setInsertionPoint(eqOp);
651 FieldDefOp auxField =
addAuxField(structDef, auxName);
652 Value self = constrainFunc.getArgument(0);
653 Value aux = builder.
create<FieldReadOp>(
654 eqOp.getLoc(), lhs.getType(), self, auxField.getNameAttr()
656 auto eqAux = builder.create<EmitEqualityOp>(eqOp.getLoc(), aux, lhs);
657 auxAssignments.push_back({auxName, lhs});
663 builder.create<EmitEqualityOp>(eqOp.getLoc(), lhs, rhs);
667 Block &computeBlock = computeFunc.
getBody().front();
668 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
670 DenseMap<Value, Value> rebuildMemo;
672 for (
const auto &assign : auxAssignments) {
674 builder.create<FieldWriteOp>(
675 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxFieldName), expr
678 buildAndEmitR1CS(moduleOp, structDef, constrainFunc, degreeMemo);
686 return std::make_unique<R1CSLoweringPass>();
This file defines helpers for manipulating APInts/APSInts for large numbers and operations over those...
#define R1CS_AUXILIARY_FIELD_PREFIX
::std::vector< FieldDefOp > getFieldDefs()
Get all FieldDefOp in this structure.
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
::mlir::Region & getBody()
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
Value getSelfValueFromCompute(FuncDefOp computeFunc)
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the IR language name.
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingAdd(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely add lhs and rhs, expanding the width of the result as necessary.
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::APSInt expandingMul(const llvm::APSInt &lhs, const llvm::APSInt &rhs)
Safely multiply lhs and rhs, expanding the width of the result as necessary.
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
std::unique_ptr< mlir::Pass > createR1CSLoweringPass()
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.