22#include <mlir/Dialect/Arith/IR/Arith.h>
23#include <mlir/Dialect/SCF/IR/SCF.h>
24#include <mlir/IR/BuiltinOps.h>
25#include <mlir/IR/Dominance.h>
27#include <llvm/ADT/DenseMap.h>
28#include <llvm/ADT/PostOrderIterator.h>
29#include <llvm/ADT/SmallVector.h>
35#define GEN_PASS_DEF_REDUNDANTOPERATIONELIMINATIONPASS
46#define DEBUG_TYPE "llzk-duplicate-op-elim"
50static auto EMPTY_OP_KEY =
reinterpret_cast<Operation *
>(1);
51static auto TOMBSTONE_OP_KEY =
reinterpret_cast<Operation *
>(2);
// Equality key for an Operation, used to detect structurally duplicate ops.
// Two comparators are equal when the wrapped operations have the same name,
// the same attribute dictionary, and the same operand list (order-insensitive
// for ops with the IsCommutative trait). The EMPTY/TOMBSTONE sentinels are only
// equal to themselves so this type can serve as a DenseMap/DenseSet key.
59class OperationComparator {
// Wraps `o` with its operands taken as-is. Sentinel keys are never
// dereferenced, so their operand list stays empty.
61 explicit OperationComparator(Operation *o) : op(o) {
62 if (op != EMPTY_OP_KEY && op != TOMBSTONE_OP_KEY) {
63 operands = SmallVector<Value>(op->getOperands());
// Wraps `o`, substituting each operand found in `m` with its mapped value.
// This lets an op whose operands came from already-eliminated duplicates
// compare equal to an op that uses the surviving originals.
67 OperationComparator(Operation *o,
const TranslationMap &m) : op(o) {
68 for (
auto operand : op->getOperands()) {
69 if (
auto it = m.find(operand); it != m.end()) {
70 operands.push_back(it->second);
72 operands.push_back(operand);
// The wrapped operation (may be one of the sentinel keys).
77 Operation *getOp()
const {
return op; }
// The (possibly remapped) operands used for comparison.
79 const SmallVector<Value> &getOperands()
const {
return operands; }
// Dereferences `op`; must not be called on a sentinel key.
81 bool isCommutative()
const {
return op->hasTrait<OpTrait::IsCommutative>(); }
83 friend bool operator==(
const OperationComparator &lhs,
const OperationComparator &rhs) {
// Sentinels compare by pointer identity only and are never dereferenced.
84 if (lhs.op == EMPTY_OP_KEY || rhs.op == EMPTY_OP_KEY || lhs.op == TOMBSTONE_OP_KEY ||
85 rhs.op == TOMBSTONE_OP_KEY) {
86 return lhs.op == rhs.op;
// Operations must share the same operation name.
89 if (lhs.op->getName() != rhs.op->getName()) {
// SCF-dialect operations take a dialect-specific path here (presumably
// because they carry regions; confirm against the full handler).
94 auto dialectName = lhs.op->getDialect()->getNamespace();
95 if (dialectName == scf::SCFDialect::getDialectNamespace()) {
// Attribute dictionaries must match exactly.
104 if (lhs.op->getAttrs() != rhs.op->getAttrs()) {
// Commutative ops are required to be binary; accept either operand order.
108 if (lhs.isCommutative()) {
110 lhs.operands.size() == 2 && rhs.operands.size() == 2,
111 "No known commutative ops have more than two arguments"
113 return (lhs.operands[0] == rhs.operands[0] && lhs.operands[1] == rhs.operands[1]) ||
114 (lhs.operands[0] == rhs.operands[1] && lhs.operands[1] == rhs.operands[0]);
// Non-commutative ops: operands must match positionally.
118 return lhs.operands == rhs.operands;
// Operands captured at construction (remapped through a TranslationMap when
// one was supplied).
123 SmallVector<Value> operands;
// DenseMapInfo specialization so OperationComparator can key a DenseMap or
// DenseSet. The sentinel keys wrap EMPTY_OP_KEY / TOMBSTONE_OP_KEY.
130template <>
struct DenseMapInfo<OperationComparator> {
131 static OperationComparator
getEmptyKey() {
return OperationComparator(EMPTY_OP_KEY); }
133 return OperationComparator(TOMBSTONE_OP_KEY);
// Sentinels hash by pointer value; they must not be dereferenced.
136 if (oc.getOp() == EMPTY_OP_KEY || oc.getOp() == TOMBSTONE_OP_KEY) {
137 return hash_value(oc.getOp());
// Real keys hash by operation name only. This is consistent with operator==,
// which requires equal names, at the cost of collisions among same-named ops.
140 return hash_value(oc.getOp()->getName());
// Presumably delegates to OperationComparator's operator== (body not visible
// in this listing; confirm).
142 static bool isEqual(
const OperationComparator &lhs,
const OperationComparator &rhs) {
// Pass that removes redundant operations: EmitEqualityOps comparing a value
// with itself, calls to constrain functions that emit no constraints, and ops
// equivalent to an earlier op that dominates them.
151class RedundantOperationEliminationPass
// Entry point. Walks the LLZK call graph in post-order
// (llvm::po_begin/po_end), sharing one SymbolTableCollection for
// call resolution.
154 void runOnOperation()
override {
155 SymbolTableCollection symbolTables;
159 auto &cga = getAnalysis<CallGraphAnalysis>();
160 const llzk::CallGraph *callGraph = &cga.getCallGraph();
161 for (
auto it = llvm::po_begin(callGraph); it != llvm::po_end(callGraph); ++it) {
162 const llzk::CallGraphNode *node = *it;
// Decides whether `fn` is a constrain function that emits no constraints.
// The walk stops at the first EmitEqualityOp, EmitContainmentOp, or AssertOp,
// or at the first call to a callee that is not itself purposeless.
// NOTE(review): this recurses through callsPurposelessConstrainFunc with no
// memoization or cycle guard. A recursive call cycle would not terminate, and
// shared callees are re-analyzed on every query. Confirm calls are acyclic.
169 bool isPurposelessConstrainFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
175 fn.walk([&](Operation *op) {
// Any emitted constraint or assertion gives the function a purpose.
176 if (isa<EmitEqualityOp, EmitContainmentOp, AssertOp>(op)) {
178 return WalkResult::interrupt();
179 }
else if (
auto callOp = dyn_cast<CallOp>(op);
// So does calling a function that is not purposeless.
180 callOp && !callsPurposelessConstrainFunc(symbolTables, callOp)) {
182 return WalkResult::interrupt();
184 return WalkResult::advance();
// Returns true iff `call`'s callee resolves successfully (`callLookup`,
// presumably produced by resolveCallable) and that callee is a purposeless
// constrain function. An unresolvable callee is conservatively not purposeless.
189 bool callsPurposelessConstrainFunc(SymbolTableCollection &symbolTables,
CallOp call) {
191 return succeeded(callLookup) && isPurposelessConstrainFunc(symbolTables, callLookup->get());
// Removes redundant operations from `fn`:
//  - EmitEqualityOps whose two sides are the same value,
//  - calls to constrain functions that emit no constraints,
//  - ops equal (per OperationComparator, after remapping operands of
//    already-eliminated ops) to an earlier op that dominates them.
// Results of removed ops are rewired to their surviving equivalents. Operand
// values left without users are then considered for removal via a worklist.
194 void runOnFunc(SymbolTableCollection &symbolTables,
FuncDefOp fn) {
// Ops classified as redundant; collected during the walk, processed after it.
196 SmallVector<Operation *> redundantOps;
// First-seen representative of each equivalence class.
197 DenseSet<OperationComparator> uniqueOps;
// A duplicate is only replaced by an op that dominates it, so the
// replacement value is available at every use.
198 DominanceInfo domInfo(fn);
// Records trivially unnecessary ops in redundantOps. When this returns true
// the walk skips the duplicate check for `op`.
200 auto unnecessaryOpCheck = [&](Operation *op) ->
bool {
// `x == x` constrains nothing.
201 if (
auto emiteq = dyn_cast<EmitEqualityOp>(op);
202 emiteq && emiteq.getLhs() == emiteq.getRhs()) {
203 redundantOps.push_back(op);
// A call to a constrain function that emits no constraints contributes none.
207 if (
auto callOp = dyn_cast<CallOp>(op);
208 callOp && callsPurposelessConstrainFunc(symbolTables, callOp)) {
209 redundantOps.push_back(op);
215 fn.walk([&](Operation *op) {
217 if (unnecessaryOpCheck(op)) {
218 return WalkResult::advance();
// `map` records eliminated-result -> surviving-result. Comparing through it
// lets chains of duplicates collapse onto a single representative.
223 OperationComparator comp(op, map);
// Only dedup against an equivalent op that dominates this one.
224 if (
auto it = uniqueOps.find(comp);
225 it != uniqueOps.end() && domInfo.dominates(it->getOp(), op)) {
226 redundantOps.push_back(op);
// Map each result to the matching result of the surviving op.
227 for (
unsigned opNum = 0; opNum < op->getNumResults(); opNum++) {
228 map[op->getResult(opNum)] = it->getOp()->getResult(opNum);
// NOTE(review): if an equal key is already present (found but not
// dominating), DenseSet::insert leaves the existing entry unchanged, so this
// op never becomes a representative. Later ops it dominates are then not
// deduplicated against it.
231 uniqueOps.insert(comp);
234 return WalkResult::advance();
// Worklist of operand values of removed ops; they may become dead.
238 std::deque<Value> operands;
240 for (
auto *op : redundantOps) {
241 LLVM_DEBUG(llvm::dbgs() <<
"Removing op: " << *op <<
'\n');
242 for (
auto result : op->getResults()) {
// A result that is still used must have a recorded replacement.
243 if (!result.getUsers().empty()) {
244 auto it = map.find(result);
246 it != map.end(),
"failed to find a replacement value for redundant operation result"
248 LLVM_DEBUG(llvm::dbgs() <<
"Replacing " << it->first <<
" with " << it->second <<
'\n');
249 result.replaceAllUsesWith(it->second);
// Queue this op's operands for the dead-value sweep below.
252 for (Value operand : op->getOperands()) {
253 operands.push_back(operand);
// Breadth-first sweep: a value whose defining op has no remaining users is
// unused, and that op's own operands are queued in turn.
// NOTE(review): the same Value may be enqueued more than once (e.g. shared
// by two removed ops). Confirm it is never examined after its defining op has
// been erased.
262 DenseSet<Value> checkedOperands;
263 while (!operands.empty()) {
264 Value operand = operands.front();
265 operands.pop_front();
266 checkedOperands.insert(operand);
// Block arguments have no defining op and are skipped.
270 if (
auto *op = operand.getDefiningOp(); op && operand.getUsers().empty()) {
271 for (
auto parentOperand : op->getOperands()) {
272 if (checkedOperands.find(parentOperand) == checkedOperands.end()) {
273 operands.push_back(parentOperand);
276 LLVM_DEBUG(llvm::dbgs() <<
"Removing unused operand: " << operand <<
'\n');
// Factory for the pass; the caller takes ownership of the returned instance.
286 return std::make_unique<RedundantOperationEliminationPass>();
bool isExternal() const
Returns true if this node is an external node.
llzk::function::FuncDefOp getCalledFunction() const
Returns the called function that the callable region represents.
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
std::unique_ptr< mlir::Pass > createRedundantOperationEliminationPass()
void ensure(bool condition, llvm::Twine errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
std::unordered_map< ConstrainRef, ConstrainRefLatticeValue, ConstrainRef::Hash > TranslationMap
static OperationComparator getTombstoneKey()
static unsigned getHashValue(const OperationComparator &oc)
static bool isEqual(const OperationComparator &lhs, const OperationComparator &rhs)
static OperationComparator getEmptyKey()