22#include <mlir/IR/BuiltinOps.h>
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
34#define GEN_PASS_DEF_REDUNDANTREADANDWRITEELIMINATIONPASS
45#define DEBUG_TYPE "llzk-redundant-read-write-pass"
53 explicit ReferenceID(Value v) {
55 if (v.getImpl() ==
reinterpret_cast<mlir::detail::ValueImpl *
>(1) ||
56 v.getImpl() ==
reinterpret_cast<mlir::detail::ValueImpl *
>(2)) {
58 }
else if (
auto constVal = dyn_cast_if_present<FeltConstantOp>(v.getDefiningOp())) {
59 identifier = constVal.getValue().getValue();
60 }
else if (
auto constIdxVal = dyn_cast_if_present<arith::ConstantIndexOp>(v.getDefiningOp())) {
61 identifier = llvm::cast<IntegerAttr>(constIdxVal.getValue()).getValue();
66 explicit ReferenceID(FlatSymbolRefAttr s) : identifier(s) {}
67 explicit ReferenceID(APInt i) : identifier(i) {}
68 explicit ReferenceID(
unsigned i) : identifier(APInt(64, i)) {}
70 bool isValue()
const {
return std::holds_alternative<Value>(identifier); }
71 bool isSymbol()
const {
return std::holds_alternative<FlatSymbolRefAttr>(identifier); }
72 bool isConst()
const {
return std::holds_alternative<APInt>(identifier); }
74 Value getValue()
const {
75 ensure(isValue(),
"does not hold Value");
76 return std::get<Value>(identifier);
79 FlatSymbolRefAttr getSymbol()
const {
80 ensure(isSymbol(),
"does not hold symbol");
81 return std::get<FlatSymbolRefAttr>(identifier);
84 APInt getConst()
const {
85 ensure(isConst(),
"does not hold const");
86 return std::get<APInt>(identifier);
89 void print(raw_ostream &os)
const {
90 if (
auto v = std::get_if<Value>(&identifier)) {
91 if (
auto opres = dyn_cast<OpResult>(*v)) {
92 os <<
'%' << opres.getResultNumber();
96 }
else if (
auto s = std::get_if<FlatSymbolRefAttr>(&identifier)) {
99 os << std::get<APInt>(identifier);
103 friend bool operator==(
const ReferenceID &lhs,
const ReferenceID &rhs) {
104 return lhs.identifier == rhs.identifier;
107 friend raw_ostream &
operator<<(raw_ostream &os,
const ReferenceID &
id) {
117 std::variant<FlatSymbolRefAttr, APInt, Value> identifier;
125template <>
struct DenseMapInfo<ReferenceID> {
127 return ReferenceID(mlir::Value(
reinterpret_cast<mlir::detail::ValueImpl *
>(1)));
130 return ReferenceID(mlir::Value(
reinterpret_cast<mlir::detail::ValueImpl *
>(2)));
134 return hash_value(r.getValue());
135 }
else if (r.isSymbol()) {
136 return hash_value(r.getSymbol());
138 return hash_value(r.getConst());
140 static bool isEqual(
const ReferenceID &lhs,
const ReferenceID &rhs) {
return lhs == rhs; }
167 template <
typename IdType>
static std::shared_ptr<ReferenceNode>
create(IdType
id, Value v) {
168 ReferenceNode n(
id, v);
170 return std::make_shared<ReferenceNode>(std::move(n));
175 std::shared_ptr<ReferenceNode> clone(
bool withChildren =
true)
const {
176 ReferenceNode copy(identifier, storedValue);
177 copy.updateLastWrite(lastWrite);
179 for (
const auto &[
id, child] : children) {
180 copy.children[id] = child->clone(withChildren);
183 return std::make_shared<ReferenceNode>(std::move(copy));
186 template <
typename IdType>
187 std::shared_ptr<ReferenceNode>
188 createChild(IdType
id, Value storedVal, std::shared_ptr<ReferenceNode> valTree =
nullptr) {
189 std::shared_ptr<ReferenceNode> child =
create(
id, storedVal);
190 child->setCurrentValue(storedVal, valTree);
191 children[child->identifier] = child;
197 template <
typename IdType> std::shared_ptr<ReferenceNode> getChild(IdType
id)
const {
198 auto it = children.find(ReferenceID(
id));
199 if (it != children.end()) {
208 template <
typename IdType>
209 std::shared_ptr<ReferenceNode> getOrCreateChild(IdType
id, Value storedVal =
nullptr) {
210 auto it = children.find(ReferenceID(
id));
211 if (it != children.end()) {
214 return createChild(
id, storedVal);
219 Operation *updateLastWrite(Operation *writeOp) {
220 Operation *old = lastWrite;
225 void setCurrentValue(Value v, std::shared_ptr<ReferenceNode> valTree =
nullptr) {
227 if (valTree !=
nullptr) {
230 children = valTree->children;
234 void invalidateChildren() { children.clear(); }
236 bool isLeaf()
const {
return children.empty(); }
238 Value getStoredValue()
const {
return storedValue; }
240 bool hasStoredValue()
const {
return storedValue !=
nullptr; }
242 void print(raw_ostream &os,
int indent = 0)
const {
243 os.indent(indent) <<
'[' << identifier;
244 if (storedValue !=
nullptr) {
245 os <<
" => " << storedValue;
248 if (!children.empty()) {
250 for (
auto &[_, child] : children) {
251 child->print(os, indent + 4);
254 os.indent(indent) <<
'}';
258 friend raw_ostream &
operator<<(raw_ostream &os,
const ReferenceNode &r) {
265 topLevelEq(
const std::shared_ptr<ReferenceNode> &lhs,
const std::shared_ptr<ReferenceNode> &rhs) {
266 return lhs->identifier == rhs->identifier && lhs->storedValue == rhs->storedValue &&
267 lhs->lastWrite == rhs->lastWrite;
270 friend std::shared_ptr<ReferenceNode> greatestCommonSubtree(
271 const std::shared_ptr<ReferenceNode> &lhs,
const std::shared_ptr<ReferenceNode> &rhs
273 if (!topLevelEq(lhs, rhs)) {
276 auto res = lhs->clone(
false);
278 for (
auto &[
id, lhsChild] : lhs->children) {
279 if (
auto it = rhs->children.find(
id); it != rhs->children.end()) {
280 auto &rhsChild = it->second;
281 if (
auto gcs = greatestCommonSubtree(lhsChild, rhsChild)) {
282 res->children[id] = gcs;
290 ReferenceID identifier;
291 mlir::Value storedValue;
292 Operation *lastWrite;
293 DenseMap<ReferenceID, std::shared_ptr<ReferenceNode>> children;
295 template <
typename IdType>
296 ReferenceNode(IdType
id, Value initialVal)
297 : identifier(id), storedValue(initialVal), lastWrite(nullptr), children() {}
300using ValueMap = DenseMap<mlir::Value, std::shared_ptr<ReferenceNode>>;
302ValueMap intersect(
const ValueMap &lhs,
const ValueMap &rhs) {
304 for (
auto &[
id, lhsValTree] : lhs) {
305 if (
auto it = rhs.find(
id); it != rhs.end()) {
306 auto &rhsValTree = it->second;
307 res[id] = greatestCommonSubtree(lhsValTree, rhsValTree);
315ValueMap cloneValueMap(
const ValueMap &orig) {
317 for (
auto &[
id, tree] : orig) {
318 res[id] = tree->clone();
323class RedundantReadAndWriteEliminationPass
325 RedundantReadAndWriteEliminationPass> {
331 void runOnOperation()
override {
332 getOperation().walk([&](FuncDefOp fn) { runOnFunc(fn); });
337 void runOnFunc(FuncDefOp fn) {
343 LLVM_DEBUG(llvm::dbgs() <<
"Running on " << fn.getName() <<
'\n');
346 DenseMap<Value, Value> replacementMap;
348 SmallVector<Value> readVals;
351 SmallVector<Operation *> redundantWrites;
355 for (
auto arg : fn.getArguments()) {
356 initState[arg] = ReferenceNode::create(arg, arg);
360 *fn.
getCallableRegion(), std::move(initState), replacementMap, readVals, redundantWrites
365 for (
auto &[orig, replace] : replacementMap) {
366 LLVM_DEBUG(llvm::dbgs() <<
"replacing " << orig <<
" with " << orig <<
'\n');
367 orig.replaceAllUsesWith(replace);
371 for (
auto *writeOp : redundantWrites) {
372 LLVM_DEBUG(llvm::dbgs() <<
"erase write: " << *writeOp <<
'\n');
378 for (
auto it = readVals.rbegin(); it != readVals.rend(); it++) {
380 if (readVal.use_empty()) {
381 LLVM_DEBUG(llvm::dbgs() <<
"erase read: " << readVal <<
'\n');
382 readVal.getDefiningOp()->erase();
387 ValueMap runOnRegion(
388 Region &r, ValueMap &&initState, DenseMap<Value, Value> &replacementMap,
389 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
392 DenseMap<Block *, ValueMap> endStates;
394 endStates[
nullptr] = initState;
395 auto getBlockState = [&endStates](Block *blockPtr) {
396 auto it = endStates.find(blockPtr);
397 ensure(it != endStates.end(),
"unknown end state means we have an unsupported backedge");
398 return cloneValueMap(it->second);
400 std::deque<Block *> frontier;
401 frontier.push_back(&r.front());
402 DenseSet<Block *> visited;
404 SmallVector<std::reference_wrapper<const ValueMap>> terminalStates;
406 while (!frontier.empty()) {
407 Block *currentBlock = frontier.front();
408 frontier.pop_front();
409 visited.insert(currentBlock);
412 ValueMap currentState;
413 auto it = currentBlock->pred_begin();
414 auto itEnd = currentBlock->pred_end();
417 currentState = getBlockState(
nullptr);
419 currentState = getBlockState(*it);
423 for (it++; it != itEnd; it++) {
424 currentState = intersect(currentState, getBlockState(*it));
429 auto endState = runOnBlock(
430 *currentBlock, std::move(currentState), replacementMap, readVals, redundantWrites
436 ensure(endStates.find(currentBlock) == endStates.end(),
"backedge");
437 endStates[currentBlock] = std::move(endState);
440 if (currentBlock->hasNoSuccessors()) {
441 terminalStates.push_back(endStates[currentBlock]);
443 for (Block *succ : currentBlock->getSuccessors()) {
444 if (visited.find(succ) == visited.end()) {
445 frontier.push_back(succ);
452 ensure(!terminalStates.empty(),
"computed no states");
453 auto finalState = terminalStates.front().get();
454 for (
auto it = terminalStates.begin() + 1; it != terminalStates.end(); it++) {
455 finalState = intersect(finalState, it->get());
461 Block &b, ValueMap &&state, DenseMap<Value, Value> &replacementMap,
462 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
464 for (Operation &op : b) {
465 runOperation(&op, state, replacementMap, readVals, redundantWrites);
469 if (!op.getRegions().empty()) {
470 SmallVector<ValueMap> regionStates;
471 for (Region ®ion : op.getRegions()) {
473 runOnRegion(region, cloneValueMap(state), replacementMap, readVals, redundantWrites);
474 regionStates.push_back(regionState);
477 ValueMap finalState = regionStates.front();
478 for (
auto it = regionStates.begin() + 1; it != regionStates.end(); it++) {
479 finalState = intersect(finalState, *it);
481 state = std::move(finalState);
484 return std::move(state);
495 Operation *op, ValueMap &state, DenseMap<Value, Value> &replacementMap,
496 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
501 auto translate = [&replacementMap](Value v) {
502 if (
auto it = replacementMap.find(v); it != replacementMap.end()) {
509 auto tryGetValTree = [&state](Value v) -> std::shared_ptr<ReferenceNode> {
510 if (
auto it = state.find(v); it != state.end()) {
518 auto doArrayReadLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass readarr) {
519 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(readarr.getArrRef()));
521 for (Value origIdx : readarr.getIndices()) {
522 Value idxVal = translate(origIdx);
523 currValTree = currValTree->getOrCreateChild(idxVal);
526 Value resVal = readarr.getResult();
527 if (!currValTree->hasStoredValue()) {
528 currValTree->setCurrentValue(resVal);
531 if (currValTree->getStoredValue() != resVal) {
533 llvm::dbgs() << readarr.getOperationName() <<
": replace " << resVal <<
" with "
534 << currValTree->getStoredValue() <<
'\n'
536 replacementMap[resVal] = currValTree->getStoredValue();
538 state[resVal] = currValTree;
540 llvm::dbgs() << readarr.getOperationName() <<
": " << resVal <<
" => " << *currValTree
545 readVals.push_back(resVal);
554 auto doArrayWriteLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass writearr) {
555 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(writearr.getArrRef()));
556 Value newVal = translate(writearr.getRvalue());
557 std::shared_ptr<ReferenceNode> valTree = tryGetValTree(newVal);
559 for (Value origIdx : writearr.getIndices()) {
560 Value idxVal = translate(origIdx);
563 if (ReferenceID(idxVal).isValue()) {
564 LLVM_DEBUG(llvm::dbgs() << writearr.getOperationName() <<
": invalidate alias\n");
565 currValTree->invalidateChildren();
567 currValTree = currValTree->getOrCreateChild(idxVal);
570 if (currValTree->getStoredValue() == newVal) {
572 llvm::dbgs() << writearr.getOperationName() <<
": subsequent " << writearr
575 redundantWrites.push_back(writearr);
577 if (Operation *lastWrite = currValTree->updateLastWrite(writearr)) {
579 llvm::dbgs() << writearr.getOperationName() <<
"writearr: replacing " << lastWrite
580 <<
" with prior write " << *lastWrite <<
'\n'
582 redundantWrites.push_back(lastWrite);
584 currValTree->setCurrentValue(newVal, valTree);
589 if (
auto newStruct = dyn_cast<CreateStructOp>(op)) {
591 auto structVal = ReferenceNode::create(newStruct, newStruct);
592 state[newStruct] = structVal;
593 LLVM_DEBUG(llvm::dbgs() << newStruct.getOperationName() <<
": " << *state[newStruct] <<
'\n');
595 readVals.push_back(newStruct);
596 }
else if (
auto readf = dyn_cast<FieldReadOp>(op)) {
597 auto structVal = state.at(translate(readf.getComponent()));
598 FlatSymbolRefAttr symbol = readf.getFieldNameAttr();
599 Value resVal = translate(readf.getVal());
601 if (
auto child = structVal->getChild(symbol)) {
603 llvm::dbgs() << readf.getOperationName() <<
": adding replacement map entry { "
604 << resVal <<
" => " << child->getStoredValue() <<
" }\n"
606 replacementMap[resVal] = child->getStoredValue();
610 state[readf] = structVal->createChild(symbol, resVal);
611 LLVM_DEBUG(llvm::dbgs() << readf.getOperationName() <<
": " << *state[readf] <<
'\n');
614 readVals.push_back(readf.getVal());
615 }
else if (
auto writef = dyn_cast<FieldWriteOp>(op)) {
616 auto structVal = state.at(translate(writef.getComponent()));
617 Value writeVal = translate(writef.getVal());
618 FlatSymbolRefAttr symbol = writef.getFieldNameAttr();
619 auto valTree = tryGetValTree(writeVal);
621 auto child = structVal->getOrCreateChild(symbol);
622 if (child->getStoredValue() == writeVal) {
624 llvm::dbgs() << writef.getOperationName() <<
": recording redundant write " << writef
627 redundantWrites.push_back(writef);
629 if (
auto *lastWrite = child->updateLastWrite(writef)) {
631 llvm::dbgs() << writef.getOperationName() <<
": recording overwritten write "
632 << *lastWrite <<
'\n'
634 redundantWrites.push_back(lastWrite);
636 child->setCurrentValue(writeVal, valTree);
638 llvm::dbgs() << writef.getOperationName() <<
": " << *child <<
" set to " << writeVal
644 else if (
auto newArray = dyn_cast<CreateArrayOp>(op)) {
645 auto arrayVal = ReferenceNode::create(newArray, newArray);
646 state[newArray] = arrayVal;
651 for (
auto elem : newArray.getElements()) {
652 Value elemVal = translate(elem);
653 auto valTree = tryGetValTree(elemVal);
654 auto elemChild = arrayVal->createChild(idx, elemVal, valTree);
656 llvm::dbgs() << newArray.getOperationName() <<
": element " << idx <<
" initialized to "
657 << *elemChild <<
'\n'
662 readVals.push_back(newArray);
663 }
else if (
auto readarr = dyn_cast<ReadArrayOp>(op)) {
664 doArrayReadLike(readarr);
665 }
else if (
auto writearr = dyn_cast<WriteArrayOp>(op)) {
666 doArrayWriteLike(writearr);
667 }
else if (
auto extractarr = dyn_cast<ExtractArrayOp>(op)) {
669 doArrayReadLike(extractarr);
670 }
else if (
auto insertarr = dyn_cast<InsertArrayOp>(op)) {
672 doArrayWriteLike(insertarr);
680 return std::make_unique<RedundantReadAndWriteEliminationPass>();
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
void ensure(bool condition, llvm::Twine errMsg)
std::unique_ptr< mlir::Pass > createRedundantReadAndWriteEliminationPass()
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs)
static ReferenceID getEmptyKey()
static ReferenceID getTombstoneKey()
static unsigned getHashValue(const ReferenceID &r)