LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKRedundantOperationEliminationPass.cpp
Go to the documentation of this file.
1//===-- LLZKRedundantOperationEliminationPass.cpp ---------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
21
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dominance.h>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/Hashing.h>
#include <llvm/ADT/PostOrderIterator.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVector.h>

#include <algorithm>
#include <deque>
32
33// Include the generated base pass class definitions.
34namespace llzk {
35#define GEN_PASS_DEF_REDUNDANTOPERATIONELIMINATIONPASS
37} // namespace llzk
38
39using namespace mlir;
40using namespace llzk;
41using namespace llzk::boolean;
42using namespace llzk::component;
43using namespace llzk::constrain;
44using namespace llzk::function;
45
46#define DEBUG_TYPE "llzk-duplicate-op-elim"
47
48namespace {
49
50static auto EMPTY_OP_KEY = reinterpret_cast<Operation *>(1);
51static auto TOMBSTONE_OP_KEY = reinterpret_cast<Operation *>(2);
52
53// Maps original -> replacement value
54using TranslationMap = DenseMap<Value, Value>;
55
59class OperationComparator {
60public:
61 explicit OperationComparator(Operation *o) : op(o) {
62 if (op != EMPTY_OP_KEY && op != TOMBSTONE_OP_KEY) {
63 operands = SmallVector<Value>(op->getOperands());
64 }
65 }
66
67 OperationComparator(Operation *o, const TranslationMap &m) : op(o) {
68 for (auto operand : op->getOperands()) {
69 if (auto it = m.find(operand); it != m.end()) {
70 operands.push_back(it->second);
71 } else {
72 operands.push_back(operand);
73 }
74 }
75 }
76
77 Operation *getOp() const { return op; }
78
79 const SmallVector<Value> &getOperands() const { return operands; }
80
81 bool isCommutative() const { return op->hasTrait<OpTrait::IsCommutative>(); }
82
83 friend bool operator==(const OperationComparator &lhs, const OperationComparator &rhs) {
84 if (lhs.op == EMPTY_OP_KEY || rhs.op == EMPTY_OP_KEY || lhs.op == TOMBSTONE_OP_KEY ||
85 rhs.op == TOMBSTONE_OP_KEY) {
86 return lhs.op == rhs.op;
87 }
88
89 if (lhs.op->getName() != rhs.op->getName()) {
90 return false;
91 }
92
93 // uninterested in operating over control-flow ops
94 auto dialectName = lhs.op->getDialect()->getNamespace();
95 if (dialectName == scf::SCFDialect::getDialectNamespace()) {
96 return false;
97 }
98
99 // This may be overly restrictive in some cases, but without knowing what
100 // potential future attributes we may have, it's safer to assume that
101 // unequal attributes => unequal operations.
102 // This covers constant operations too, as the constant is an attribute,
103 // not an operand.
104 if (lhs.op->getAttrs() != rhs.op->getAttrs()) {
105 return false;
107 // For commutative operations, just check if the operands contain the same set in any order
108 if (lhs.isCommutative()) {
110 lhs.operands.size() == 2 && rhs.operands.size() == 2,
111 "No known commutative ops have more than two arguments"
112 );
113 return (lhs.operands[0] == rhs.operands[0] && lhs.operands[1] == rhs.operands[1]) ||
114 (lhs.operands[0] == rhs.operands[1] && lhs.operands[1] == rhs.operands[0]);
115 }
116
117 // The default case requires an exact match per argument
118 return lhs.operands == rhs.operands;
119 }
121private:
122 Operation *op;
123 SmallVector<Value> operands;
124};
125
126} // namespace
127
128namespace llvm {
129
130template <> struct DenseMapInfo<OperationComparator> {
131 static OperationComparator getEmptyKey() { return OperationComparator(EMPTY_OP_KEY); }
132 static inline OperationComparator getTombstoneKey() {
133 return OperationComparator(TOMBSTONE_OP_KEY);
134 }
135 static unsigned getHashValue(const OperationComparator &oc) {
136 if (oc.getOp() == EMPTY_OP_KEY || oc.getOp() == TOMBSTONE_OP_KEY) {
137 return hash_value(oc.getOp());
138 }
139 // Just hash on name to force more thorough equality checks by operation type.
140 return hash_value(oc.getOp()->getName());
141 }
142 static bool isEqual(const OperationComparator &lhs, const OperationComparator &rhs) {
143 return lhs == rhs;
144 }
145};
146
147} // namespace llvm
148
149namespace {
150
151class RedundantOperationEliminationPass
152 : public llzk::impl::RedundantOperationEliminationPassBase<RedundantOperationEliminationPass> {
153
154 void runOnOperation() override {
155 SymbolTableCollection symbolTables;
156 // Traverse functions from the bottom of the call graph up.
157 // This way, we may create empty constrain functions to which we can eliminate
158 // calls.
159 auto &cga = getAnalysis<CallGraphAnalysis>();
160 const llzk::CallGraph *callGraph = &cga.getCallGraph();
161 for (auto it = llvm::po_begin(callGraph); it != llvm::po_end(callGraph); ++it) {
162 const llzk::CallGraphNode *node = *it;
163 if (!node->isExternal()) {
164 runOnFunc(symbolTables, node->getCalledFunction());
165 }
166 }
167 }
168
169 bool isPurposelessConstrainFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
170 if (!fn.isStructConstrain()) {
171 return false;
172 }
173
174 bool res = true;
175 fn.walk([&](Operation *op) {
176 if (isa<EmitEqualityOp, EmitContainmentOp, AssertOp>(op)) {
177 res = false;
178 return WalkResult::interrupt();
179 } else if (auto callOp = dyn_cast<CallOp>(op);
180 callOp && !callsPurposelessConstrainFunc(symbolTables, callOp)) {
181 res = false;
182 return WalkResult::interrupt();
183 }
184 return WalkResult::advance();
185 });
186 return res;
187 }
188
189 bool callsPurposelessConstrainFunc(SymbolTableCollection &symbolTables, CallOp call) {
190 auto callLookup = resolveCallable<FuncDefOp>(symbolTables, call);
191 return succeeded(callLookup) && isPurposelessConstrainFunc(symbolTables, callLookup->get());
192 }
193
194 void runOnFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
195 TranslationMap map;
196 SmallVector<Operation *> redundantOps;
197 DenseSet<OperationComparator> uniqueOps;
198 DominanceInfo domInfo(fn);
199
200 auto unnecessaryOpCheck = [&](Operation *op) -> bool {
201 if (auto emiteq = dyn_cast<EmitEqualityOp>(op);
202 emiteq && emiteq.getLhs() == emiteq.getRhs()) {
203 redundantOps.push_back(op);
204 return true;
205 }
206
207 if (auto callOp = dyn_cast<CallOp>(op);
208 callOp && callsPurposelessConstrainFunc(symbolTables, callOp)) {
209 redundantOps.push_back(op);
210 return true;
211 }
212 return false;
213 };
214
215 fn.walk([&](Operation *op) {
216 // Case 1: The operation itself is unnecessary.
217 if (unnecessaryOpCheck(op)) {
218 return WalkResult::advance();
219 }
220
221 // Case 2: An equivalent operation A has already been performed before
222 // the current operation B and A dominates B.
223 OperationComparator comp(op, map);
224 if (auto it = uniqueOps.find(comp);
225 it != uniqueOps.end() && domInfo.dominates(it->getOp(), op)) {
226 redundantOps.push_back(op);
227 for (unsigned opNum = 0; opNum < op->getNumResults(); opNum++) {
228 map[op->getResult(opNum)] = it->getOp()->getResult(opNum);
229 }
230 } else {
231 uniqueOps.insert(comp);
232 }
233
234 return WalkResult::advance();
235 });
236
237 // Track the operands of removed ops.
238 std::deque<Value> operands;
239
240 for (auto *op : redundantOps) {
241 LLVM_DEBUG(llvm::dbgs() << "Removing op: " << *op << '\n');
242 for (auto result : op->getResults()) {
243 if (!result.getUsers().empty()) {
244 auto it = map.find(result);
245 ensure(
246 it != map.end(), "failed to find a replacement value for redundant operation result"
247 );
248 LLVM_DEBUG(llvm::dbgs() << "Replacing " << it->first << " with " << it->second << '\n');
249 result.replaceAllUsesWith(it->second);
250 }
251 }
252 for (Value operand : op->getOperands()) {
253 operands.push_back(operand);
254 }
255 op->erase();
256 }
257
258 // Check if any of the operands are unused. If so, remove them, and check
259 // their operands until all operands have been checked.
260
261 // Make sure operands aren't freed multiple times
262 DenseSet<Value> checkedOperands;
263 while (!operands.empty()) {
264 Value operand = operands.front();
265 operands.pop_front();
266 checkedOperands.insert(operand);
267
268 // We only want to remove operands that are defined by an operation and
269 // are not block arguments.
270 if (auto *op = operand.getDefiningOp(); op && operand.getUsers().empty()) {
271 for (auto parentOperand : op->getOperands()) {
272 if (checkedOperands.find(parentOperand) == checkedOperands.end()) {
273 operands.push_back(parentOperand);
274 }
275 }
276 LLVM_DEBUG(llvm::dbgs() << "Removing unused operand: " << operand << '\n');
277 op->erase();
278 }
279 }
280 }
281};
282
283} // namespace
284
286 return std::make_unique<RedundantOperationEliminationPass>();
287};
bool isExternal() const
Returns true if this node is an external node.
Definition CallGraph.cpp:38
llzk::function::FuncDefOp getCalledFunction() const
Returns the called function that the callable region represents.
Definition CallGraph.cpp:47
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
Definition Ops.h.inc:625
std::unique_ptr< mlir::Pass > createRedundantOperationEliminationPass()
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:32
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
DenseMap< Value, Value > TranslationMap
static unsigned getHashValue(const OperationComparator &oc)
static bool isEqual(const OperationComparator &lhs, const OperationComparator &rhs)