LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKLoweringUtils.cpp
Go to the documentation of this file.
1//===-- LLZKLoweringUtils.cpp --------------------------------*- C++ -*----===//
2//
3// Shared utility function implementations for LLZK lowering passes.
4//
5//===----------------------------------------------------------------------===//
6
8
9#include <mlir/IR/Block.h>
10#include <mlir/IR/Builders.h>
11#include <mlir/IR/BuiltinOps.h>
12#include <mlir/IR/Operation.h>
13#include <mlir/Support/LogicalResult.h>
14
15#include <llvm/ADT/SmallVector.h>
16#include <llvm/Support/raw_ostream.h>
17
18using namespace mlir;
19using namespace llzk;
20using namespace llzk::felt;
21using namespace llzk::function;
22using namespace llzk::component;
23using namespace llzk::constrain;
24
25namespace llzk {
26
28 Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap<Value, Value> &memo
29) {
30 if (auto it = memo.find(val); it != memo.end()) {
31 return it->second;
32 }
33
34 if (auto barg = val.dyn_cast<BlockArgument>()) {
35 unsigned index = barg.getArgNumber();
36 Value mapped = computeFunc.getArgument(index - 1);
37 return memo[val] = mapped;
38 }
39
40 if (auto readOp = val.getDefiningOp<FieldReadOp>()) {
41 Value self = computeFunc.getSelfValueFromCompute();
42 Value rebuilt = builder.create<FieldReadOp>(
43 readOp.getLoc(), readOp.getType(), self, readOp.getFieldNameAttr().getAttr()
44 );
45 return memo[val] = rebuilt;
46 }
47
48 if (auto add = val.getDefiningOp<AddFeltOp>()) {
49 Value lhs = rebuildExprInCompute(add.getLhs(), computeFunc, builder, memo);
50 Value rhs = rebuildExprInCompute(add.getRhs(), computeFunc, builder, memo);
51 return memo[val] = builder.create<AddFeltOp>(add.getLoc(), add.getType(), lhs, rhs);
52 }
53
54 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
55 Value lhs = rebuildExprInCompute(sub.getLhs(), computeFunc, builder, memo);
56 Value rhs = rebuildExprInCompute(sub.getRhs(), computeFunc, builder, memo);
57 return memo[val] = builder.create<SubFeltOp>(sub.getLoc(), sub.getType(), lhs, rhs);
58 }
59
60 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
61 Value lhs = rebuildExprInCompute(mul.getLhs(), computeFunc, builder, memo);
62 Value rhs = rebuildExprInCompute(mul.getRhs(), computeFunc, builder, memo);
63 return memo[val] = builder.create<MulFeltOp>(mul.getLoc(), mul.getType(), lhs, rhs);
64 }
65
66 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
67 Value operand = rebuildExprInCompute(neg.getOperand(), computeFunc, builder, memo);
68 return memo[val] = builder.create<NegFeltOp>(neg.getLoc(), neg.getType(), operand);
69 }
70
71 if (auto div = val.getDefiningOp<DivFeltOp>()) {
72 Value lhs = rebuildExprInCompute(div.getLhs(), computeFunc, builder, memo);
73 Value rhs = rebuildExprInCompute(div.getRhs(), computeFunc, builder, memo);
74 return memo[val] = builder.create<DivFeltOp>(div.getLoc(), div.getType(), lhs, rhs);
75 }
76
77 if (auto c = val.getDefiningOp<FeltConstantOp>()) {
78 return memo[val] = builder.create<FeltConstantOp>(c.getLoc(), c.getValue());
79 }
80
81 llvm::errs() << "Unhandled op in rebuildExprInCompute: " << val << '\n';
82 llvm_unreachable("Unsupported op kind");
83}
84
85LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix) {
86 bool conflictFound = false;
87
88 structDef.walk([&conflictFound, &prefix](FieldDefOp fieldDefOp) {
89 if (fieldDefOp.getName().starts_with(prefix)) {
90 (fieldDefOp.emitError() << "Field name '" << fieldDefOp.getName()
91 << "' conflicts with reserved prefix '" << prefix << '\'')
92 .report();
93 conflictFound = true;
94 }
95 });
96
97 return failure(conflictFound);
98}
99
100void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) {
101 assert(afterOp && "afterOp must be a valid Operation*");
102
103 for (auto &use : llvm::make_early_inc_range(oldVal.getUses())) {
104 Operation *user = use.getOwner();
105
106 // Skip uses that are:
107 // - Before afterOp in the same block.
108 // - Inside afterOp itself.
109 if ((user->getBlock() == afterOp->getBlock()) &&
110 (user == afterOp || user->isBeforeInBlock(afterOp))) {
111 continue;
112 }
113
114 // Replace this use of oldVal with newVal.
115 use.set(newVal);
116 }
117}
118
119FieldDefOp addAuxField(StructDefOp structDef, StringRef name) {
120 OpBuilder builder(structDef);
121 builder.setInsertionPointToEnd(&structDef.getBody().back());
122 return builder.create<FieldDefOp>(
123 structDef.getLoc(), builder.getStringAttr(name), builder.getType<FeltType>()
124 );
125}
126
127unsigned getFeltDegree(Value val, DenseMap<Value, unsigned> &memo) {
128 if (auto it = memo.find(val); it != memo.end()) {
129 return it->second;
130 }
131
132 if (isa<FeltConstantOp>(val.getDefiningOp())) {
133 return memo[val] = 0;
134 }
135 if (isa<FeltNonDetOp, FieldReadOp>(val.getDefiningOp()) || isa<BlockArgument>(val)) {
136 return memo[val] = 1;
137 }
138
139 if (auto add = val.getDefiningOp<AddFeltOp>()) {
140 return memo[val] =
141 std::max(getFeltDegree(add.getLhs(), memo), getFeltDegree(add.getRhs(), memo));
142 }
143 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
144 return memo[val] =
145 std::max(getFeltDegree(sub.getLhs(), memo), getFeltDegree(sub.getRhs(), memo));
146 }
147 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
148 return memo[val] = getFeltDegree(mul.getLhs(), memo) + getFeltDegree(mul.getRhs(), memo);
149 }
150 if (auto div = val.getDefiningOp<DivFeltOp>()) {
151 return memo[val] = getFeltDegree(div.getLhs(), memo) + getFeltDegree(div.getRhs(), memo);
152 }
153 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
154 return memo[val] = getFeltDegree(neg.getOperand(), memo);
155 }
156
157 llvm::errs() << "Unhandled felt op in degree computation: " << val << '\n';
158 llvm_unreachable("Unhandled op in getFeltDegree");
159}
160
161} // namespace llzk
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
MlirStringRef name
::mlir::Region & getBody()
Definition Ops.cpp.inc:1810
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:314
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
LogicalResult checkForAuxFieldConflicts(StructDefOp structDef, StringRef prefix)
FieldDefOp addAuxField(StructDefOp structDef, StringRef name)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
unsigned getFeltDegree(Value val, DenseMap< Value, unsigned > &memo)