LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKLoweringUtils.cpp
Go to the documentation of this file.
1//===-- LLZKLoweringUtils.cpp --------------------------------*- C++ -*----===//
2//
3// Shared utility function implementations for LLZK lowering passes.
4//
5//===----------------------------------------------------------------------===//
6
8
9#include <mlir/IR/Block.h>
10#include <mlir/IR/Builders.h>
11#include <mlir/IR/BuiltinOps.h>
12#include <mlir/IR/Operation.h>
13#include <mlir/Support/LogicalResult.h>
14
15#include <llvm/ADT/SmallVector.h>
16#include <llvm/Support/raw_ostream.h>
17
18using namespace mlir;
19using namespace llzk;
20using namespace llzk::felt;
21using namespace llzk::function;
22using namespace llzk::component;
23using namespace llzk::constrain;
24
25namespace llzk {
26
28 // Get the single block of the function body
29 Region &body = computeFunc.getBody();
30 assert(!body.empty() && "compute() function body is empty");
31 Block &block = body.back();
32
33 // The terminator should be the return op
34 Operation *terminator = block.getTerminator();
35 assert(terminator && "compute() function has no terminator");
36 auto retOp = dyn_cast<ReturnOp>(terminator);
37 if (!retOp) {
38 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
39 << terminator->getName() << "'\n";
40 llvm_unreachable("compute() function must end with ReturnOp");
41 }
42 return retOp.getOperands().front();
43}
44
46 Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap<Value, Value> &memo
47) {
48 if (auto it = memo.find(val); it != memo.end()) {
49 return it->second;
50 }
51
52 if (auto barg = val.dyn_cast<BlockArgument>()) {
53 unsigned index = barg.getArgNumber();
54 Value mapped = computeFunc.getArgument(index - 1);
55 return memo[val] = mapped;
56 }
57
58 if (auto readOp = val.getDefiningOp<FieldReadOp>()) {
59 Value self = getSelfValueFromCompute(computeFunc);
60 Value rebuilt = builder.create<FieldReadOp>(
61 readOp.getLoc(), readOp.getType(), self, readOp.getFieldNameAttr().getAttr()
62 );
63 return memo[val] = rebuilt;
64 }
65
66 if (auto add = val.getDefiningOp<AddFeltOp>()) {
67 Value lhs = rebuildExprInCompute(add.getLhs(), computeFunc, builder, memo);
68 Value rhs = rebuildExprInCompute(add.getRhs(), computeFunc, builder, memo);
69 return memo[val] = builder.create<AddFeltOp>(add.getLoc(), add.getType(), lhs, rhs);
70 }
71
72 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
73 Value lhs = rebuildExprInCompute(sub.getLhs(), computeFunc, builder, memo);
74 Value rhs = rebuildExprInCompute(sub.getRhs(), computeFunc, builder, memo);
75 return memo[val] = builder.create<SubFeltOp>(sub.getLoc(), sub.getType(), lhs, rhs);
76 }
77
78 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
79 Value lhs = rebuildExprInCompute(mul.getLhs(), computeFunc, builder, memo);
80 Value rhs = rebuildExprInCompute(mul.getRhs(), computeFunc, builder, memo);
81 return memo[val] = builder.create<MulFeltOp>(mul.getLoc(), mul.getType(), lhs, rhs);
82 }
83
84 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
85 Value operand = rebuildExprInCompute(neg.getOperand(), computeFunc, builder, memo);
86 return memo[val] = builder.create<NegFeltOp>(neg.getLoc(), neg.getType(), operand);
87 }
88
89 if (auto div = val.getDefiningOp<DivFeltOp>()) {
90 Value lhs = rebuildExprInCompute(div.getLhs(), computeFunc, builder, memo);
91 Value rhs = rebuildExprInCompute(div.getRhs(), computeFunc, builder, memo);
92 return memo[val] = builder.create<DivFeltOp>(div.getLoc(), div.getType(), lhs, rhs);
93 }
94
95 if (auto c = val.getDefiningOp<FeltConstantOp>()) {
96 return memo[val] = builder.create<FeltConstantOp>(c.getLoc(), c.getValue());
97 }
98
99 llvm::errs() << "Unhandled op in rebuildExprInCompute: " << val << '\n';
100 llvm_unreachable("Unsupported op kind");
101}
102
103LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix) {
104 bool conflictFound = false;
105
106 structDef.walk([&conflictFound, &prefix](FieldDefOp fieldDefOp) {
107 if (fieldDefOp.getName().starts_with(prefix)) {
108 (fieldDefOp.emitError() << "Field name '" << fieldDefOp.getName()
109 << "' conflicts with reserved prefix '" << prefix << '\'')
110 .report();
111 conflictFound = true;
112 }
113 });
114
115 return failure(conflictFound);
116}
117
118void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) {
119 assert(afterOp && "afterOp must be a valid Operation*");
120
121 for (auto &use : llvm::make_early_inc_range(oldVal.getUses())) {
122 Operation *user = use.getOwner();
123
124 // Skip uses that are:
125 // - Before afterOp in the same block.
126 // - Inside afterOp itself.
127 if ((user->getBlock() == afterOp->getBlock()) &&
128 (user == afterOp || user->isBeforeInBlock(afterOp))) {
129 continue;
130 }
131
132 // Replace this use of oldVal with newVal.
133 use.set(newVal);
134 }
135}
136
137FieldDefOp addAuxField(StructDefOp structDef, StringRef name) {
138 OpBuilder builder(structDef);
139 builder.setInsertionPointToEnd(&structDef.getBody().back());
140 return builder.create<FieldDefOp>(
141 structDef.getLoc(), builder.getStringAttr(name), builder.getType<FeltType>()
142 );
143}
144
145unsigned getFeltDegree(Value val, DenseMap<Value, unsigned> &memo) {
146 if (auto it = memo.find(val); it != memo.end()) {
147 return it->second;
148 }
149
150 if (isa<FeltConstantOp>(val.getDefiningOp())) {
151 return memo[val] = 0;
152 }
153 if (isa<FeltNonDetOp, FieldReadOp>(val.getDefiningOp()) || isa<BlockArgument>(val)) {
154 return memo[val] = 1;
155 }
156
157 if (auto add = val.getDefiningOp<AddFeltOp>()) {
158 return memo[val] =
159 std::max(getFeltDegree(add.getLhs(), memo), getFeltDegree(add.getRhs(), memo));
160 }
161 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
162 return memo[val] =
163 std::max(getFeltDegree(sub.getLhs(), memo), getFeltDegree(sub.getRhs(), memo));
164 }
165 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
166 return memo[val] = getFeltDegree(mul.getLhs(), memo) + getFeltDegree(mul.getRhs(), memo);
167 }
168 if (auto div = val.getDefiningOp<DivFeltOp>()) {
169 return memo[val] = getFeltDegree(div.getLhs(), memo) + getFeltDegree(div.getRhs(), memo);
170 }
171 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
172 return memo[val] = getFeltDegree(neg.getOperand(), memo);
173 }
174
175 llvm::errs() << "Unhandled felt op in degree computation: " << val << '\n';
176 llvm_unreachable("Unhandled op in getFeltDegree");
177}
178
179} // namespace llzk
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
MlirStringRef name
::mlir::Region & getBody()
Definition Ops.cpp.inc:1810
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
Definition Ops.cpp:143
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:705
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
Value getSelfValueFromCompute(FuncDefOp computeFunc)
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
unsigned getFeltDegree(Value val, DenseMap< Value, unsigned > &memo)