22#include <mlir/IR/BuiltinOps.h>
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
34#define GEN_PASS_DEF_REDUNDANTREADANDWRITEELIMINATIONPASS
45#define DEBUG_TYPE "llzk-redundant-read-write-pass"
53 explicit ReferenceID(Value v) {
55 if (v.getImpl() ==
reinterpret_cast<mlir::detail::ValueImpl *
>(1) ||
56 v.getImpl() ==
reinterpret_cast<mlir::detail::ValueImpl *
>(2)) {
58 }
else if (
auto constVal = dyn_cast_if_present<FeltConstantOp>(v.getDefiningOp())) {
59 identifier = constVal.getValue();
60 }
else if (
auto constIdxVal = dyn_cast_if_present<arith::ConstantIndexOp>(v.getDefiningOp())) {
61 identifier = llvm::cast<IntegerAttr>(constIdxVal.getValue()).getValue();
66 explicit ReferenceID(FlatSymbolRefAttr s) : identifier(s) {}
67 explicit ReferenceID(
const APInt &i) : identifier(i) {}
68 explicit ReferenceID(
unsigned i) : identifier(APInt(64, i)) {}
70 bool isValue()
const {
return std::holds_alternative<Value>(identifier); }
71 bool isSymbol()
const {
return std::holds_alternative<FlatSymbolRefAttr>(identifier); }
72 bool isConst()
const {
return std::holds_alternative<APInt>(identifier); }
74 Value getValue()
const {
75 ensure(isValue(),
"does not hold Value");
76 return std::get<Value>(identifier);
79 FlatSymbolRefAttr getSymbol()
const {
80 ensure(isSymbol(),
"does not hold symbol");
81 return std::get<FlatSymbolRefAttr>(identifier);
84 APInt getConst()
const {
85 ensure(isConst(),
"does not hold const");
86 return std::get<APInt>(identifier);
89 void print(raw_ostream &os)
const {
90 if (
auto v = std::get_if<Value>(&identifier)) {
91 if (
auto opres = dyn_cast<OpResult>(*v)) {
92 os <<
'%' << opres.getResultNumber();
96 }
else if (
auto s = std::get_if<FlatSymbolRefAttr>(&identifier)) {
99 os << std::get<APInt>(identifier);
103 friend bool operator==(
const ReferenceID &lhs,
const ReferenceID &rhs) {
104 return lhs.identifier == rhs.identifier;
107 friend raw_ostream &
operator<<(raw_ostream &os,
const ReferenceID &
id) {
117 std::variant<FlatSymbolRefAttr, APInt, Value> identifier;
125template <>
struct DenseMapInfo<ReferenceID> {
127 return ReferenceID(mlir::Value(
reinterpret_cast<mlir::detail::ValueImpl *
>(1)));
130 return ReferenceID(mlir::Value(
reinterpret_cast<mlir::detail::ValueImpl *
>(2)));
134 return hash_value(r.getValue());
135 }
else if (r.isSymbol()) {
136 return hash_value(r.getSymbol());
138 return hash_value(r.getConst());
140 static bool isEqual(
const ReferenceID &lhs,
const ReferenceID &rhs) {
return lhs == rhs; }
167 template <
typename IdType>
static std::shared_ptr<ReferenceNode>
create(IdType
id, Value v) {
168 ReferenceNode n(
id, v);
170 return std::make_shared<ReferenceNode>(std::move(n));
175 std::shared_ptr<ReferenceNode> clone(
bool withChildren =
true)
const {
176 ReferenceNode copy(identifier, storedValue);
177 copy.updateLastWrite(lastWrite);
179 for (
const auto &[
id, child] : children) {
180 copy.children[id] = child->clone(withChildren);
183 return std::make_shared<ReferenceNode>(std::move(copy));
186 template <
typename IdType>
187 std::shared_ptr<ReferenceNode>
188 createChild(IdType
id, Value storedVal, std::shared_ptr<ReferenceNode> valTree =
nullptr) {
189 std::shared_ptr<ReferenceNode> child =
create(
id, storedVal);
190 child->setCurrentValue(storedVal, valTree);
191 children[child->identifier] = child;
197 template <
typename IdType> std::shared_ptr<ReferenceNode> getChild(IdType
id)
const {
198 auto it = children.find(ReferenceID(
id));
199 if (it != children.end()) {
208 template <
typename IdType>
209 std::shared_ptr<ReferenceNode> getOrCreateChild(IdType
id, Value storedVal =
nullptr) {
210 auto it = children.find(ReferenceID(
id));
211 if (it != children.end()) {
214 return createChild(
id, storedVal);
219 Operation *updateLastWrite(Operation *writeOp) {
220 Operation *old = lastWrite;
225 void setCurrentValue(Value v, std::shared_ptr<ReferenceNode> valTree =
nullptr) {
227 if (valTree !=
nullptr) {
230 children = valTree->children;
234 void invalidateChildren() { children.clear(); }
236 bool isLeaf()
const {
return children.empty(); }
238 Value getStoredValue()
const {
return storedValue; }
240 bool hasStoredValue()
const {
return storedValue !=
nullptr; }
242 void print(raw_ostream &os,
int indent = 0)
const {
243 os.indent(indent) <<
'[' << identifier;
244 if (storedValue !=
nullptr) {
245 os <<
" => " << storedValue;
248 if (!children.empty()) {
250 for (
auto &[_, child] : children) {
251 child->print(os, indent + 4);
254 os.indent(indent) <<
'}';
259 friend raw_ostream &
operator<<(raw_ostream &os,
const ReferenceNode &r) {
266 topLevelEq(
const std::shared_ptr<ReferenceNode> &lhs,
const std::shared_ptr<ReferenceNode> &rhs) {
267 return lhs->identifier == rhs->identifier && lhs->storedValue == rhs->storedValue &&
268 lhs->lastWrite == rhs->lastWrite;
271 friend std::shared_ptr<ReferenceNode> greatestCommonSubtree(
272 const std::shared_ptr<ReferenceNode> &lhs,
const std::shared_ptr<ReferenceNode> &rhs
274 if (!topLevelEq(lhs, rhs)) {
277 auto res = lhs->clone(
false);
279 for (
auto &[
id, lhsChild] : lhs->children) {
280 if (
auto it = rhs->children.find(
id); it != rhs->children.end()) {
281 auto &rhsChild = it->second;
282 if (
auto gcs = greatestCommonSubtree(lhsChild, rhsChild)) {
283 res->children[id] = gcs;
291 ReferenceID identifier;
292 mlir::Value storedValue;
293 Operation *lastWrite;
294 DenseMap<ReferenceID, std::shared_ptr<ReferenceNode>> children;
296 template <
typename IdType>
297 ReferenceNode(IdType
id, Value initialVal)
298 : identifier(id), storedValue(initialVal), lastWrite(nullptr), children() {}
301using ValueMap = DenseMap<mlir::Value, std::shared_ptr<ReferenceNode>>;
303ValueMap intersect(
const ValueMap &lhs,
const ValueMap &rhs) {
305 for (
auto &[
id, lhsValTree] : lhs) {
306 if (
auto it = rhs.find(
id); it != rhs.end()) {
307 auto &rhsValTree = it->second;
308 res[id] = greatestCommonSubtree(lhsValTree, rhsValTree);
316ValueMap cloneValueMap(
const ValueMap &orig) {
318 for (
auto &[
id, tree] : orig) {
319 res[id] = tree->clone();
324class RedundantReadAndWriteEliminationPass
326 RedundantReadAndWriteEliminationPass> {
332 void runOnOperation()
override {
333 getOperation().walk([&](FuncDefOp fn) { runOnFunc(fn); });
338 void runOnFunc(FuncDefOp fn) {
344 LLVM_DEBUG(llvm::dbgs() <<
"Running on " << fn.getName() <<
'\n');
347 DenseMap<Value, Value> replacementMap;
349 SmallVector<Value> readVals;
352 SmallVector<Operation *> redundantWrites;
356 for (
auto arg : fn.getArguments()) {
357 initState[arg] = ReferenceNode::create(arg, arg);
361 *fn.
getCallableRegion(), std::move(initState), replacementMap, readVals, redundantWrites
366 for (
auto &[orig, replace] : replacementMap) {
367 LLVM_DEBUG(llvm::dbgs() <<
"replacing " << orig <<
" with " << orig <<
'\n');
368 orig.replaceAllUsesWith(replace);
372 for (
auto *writeOp : redundantWrites) {
373 LLVM_DEBUG(llvm::dbgs() <<
"erase write: " << *writeOp <<
'\n');
379 for (
auto it = readVals.rbegin(); it != readVals.rend(); it++) {
381 if (readVal.use_empty()) {
382 LLVM_DEBUG(llvm::dbgs() <<
"erase read: " << readVal <<
'\n');
383 readVal.getDefiningOp()->erase();
388 ValueMap runOnRegion(
389 Region &r, ValueMap &&initState, DenseMap<Value, Value> &replacementMap,
390 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
393 DenseMap<Block *, ValueMap> endStates;
395 endStates[
nullptr] = initState;
396 auto getBlockState = [&endStates](Block *blockPtr) {
397 auto it = endStates.find(blockPtr);
398 ensure(it != endStates.end(),
"unknown end state means we have an unsupported backedge");
399 return cloneValueMap(it->second);
401 std::deque<Block *> frontier;
402 frontier.push_back(&r.front());
403 DenseSet<Block *> visited;
405 SmallVector<std::reference_wrapper<const ValueMap>> terminalStates;
407 while (!frontier.empty()) {
408 Block *currentBlock = frontier.front();
409 frontier.pop_front();
410 visited.insert(currentBlock);
413 ValueMap currentState;
414 auto it = currentBlock->pred_begin();
415 auto itEnd = currentBlock->pred_end();
418 currentState = getBlockState(
nullptr);
420 currentState = getBlockState(*it);
424 for (it++; it != itEnd; it++) {
425 currentState = intersect(currentState, getBlockState(*it));
430 auto endState = runOnBlock(
431 *currentBlock, std::move(currentState), replacementMap, readVals, redundantWrites
437 ensure(endStates.find(currentBlock) == endStates.end(),
"backedge");
438 endStates[currentBlock] = std::move(endState);
441 if (currentBlock->hasNoSuccessors()) {
442 terminalStates.push_back(endStates[currentBlock]);
444 for (Block *succ : currentBlock->getSuccessors()) {
445 if (visited.find(succ) == visited.end()) {
446 frontier.push_back(succ);
453 ensure(!terminalStates.empty(),
"computed no states");
454 auto finalState = terminalStates.front().get();
455 for (
auto it = terminalStates.begin() + 1; it != terminalStates.end(); it++) {
456 finalState = intersect(finalState, it->get());
462 Block &b, ValueMap &&state, DenseMap<Value, Value> &replacementMap,
463 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
465 for (Operation &op : b) {
466 runOperation(&op, state, replacementMap, readVals, redundantWrites);
470 if (!op.getRegions().empty()) {
471 SmallVector<ValueMap> regionStates;
472 for (Region ®ion : op.getRegions()) {
474 runOnRegion(region, cloneValueMap(state), replacementMap, readVals, redundantWrites);
475 regionStates.push_back(regionState);
478 ValueMap finalState = regionStates.front();
479 for (
auto it = regionStates.begin() + 1; it != regionStates.end(); it++) {
480 finalState = intersect(finalState, *it);
482 state = std::move(finalState);
485 return std::move(state);
496 Operation *op, ValueMap &state, DenseMap<Value, Value> &replacementMap,
497 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
502 auto translate = [&replacementMap](Value v) {
503 if (
auto it = replacementMap.find(v); it != replacementMap.end()) {
510 auto tryGetValTree = [&state](Value v) -> std::shared_ptr<ReferenceNode> {
511 if (
auto it = state.find(v); it != state.end()) {
520 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(readarr.getArrRef()));
522 for (Value origIdx : readarr.getIndices()) {
523 Value idxVal = translate(origIdx);
524 currValTree = currValTree->getOrCreateChild(idxVal);
527 Value resVal = readarr.getResult();
528 if (!currValTree->hasStoredValue()) {
529 currValTree->setCurrentValue(resVal);
532 if (currValTree->getStoredValue() != resVal) {
534 llvm::dbgs() << readarr.getOperationName() <<
": replace " << resVal <<
" with "
535 << currValTree->getStoredValue() <<
'\n'
537 replacementMap[resVal] = currValTree->getStoredValue();
539 state[resVal] = currValTree;
541 llvm::dbgs() << readarr.getOperationName() <<
": " << resVal <<
" => " << *currValTree
546 readVals.push_back(resVal);
556 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(writearr.getArrRef()));
557 Value newVal = translate(writearr.getRvalue());
558 std::shared_ptr<ReferenceNode> valTree = tryGetValTree(newVal);
560 for (Value origIdx : writearr.getIndices()) {
561 Value idxVal = translate(origIdx);
564 if (ReferenceID(idxVal).isValue()) {
565 LLVM_DEBUG(llvm::dbgs() << writearr.getOperationName() <<
": invalidate alias\n");
566 currValTree->invalidateChildren();
568 currValTree = currValTree->getOrCreateChild(idxVal);
571 if (currValTree->getStoredValue() == newVal) {
573 llvm::dbgs() << writearr.getOperationName() <<
": subsequent " << writearr
576 redundantWrites.push_back(writearr);
578 if (Operation *lastWrite = currValTree->updateLastWrite(writearr)) {
580 llvm::dbgs() << writearr.getOperationName() <<
"writearr: replacing " << lastWrite
581 <<
" with prior write " << *lastWrite <<
'\n'
583 redundantWrites.push_back(lastWrite);
585 currValTree->setCurrentValue(newVal, valTree);
590 if (
auto newStruct = dyn_cast<CreateStructOp>(op)) {
592 auto structVal = ReferenceNode::create(newStruct, newStruct);
593 state[newStruct] = structVal;
594 LLVM_DEBUG(llvm::dbgs() << newStruct.getOperationName() <<
": " << *state[newStruct] <<
'\n');
596 readVals.push_back(newStruct);
597 }
else if (
auto readf = dyn_cast<FieldReadOp>(op)) {
598 auto structVal = state.at(translate(readf.getComponent()));
599 FlatSymbolRefAttr symbol = readf.getFieldNameAttr();
600 Value resVal = translate(readf.getVal());
602 if (
auto child = structVal->getChild(symbol)) {
604 llvm::dbgs() << readf.getOperationName() <<
": adding replacement map entry { "
605 << resVal <<
" => " << child->getStoredValue() <<
" }\n"
607 replacementMap[resVal] = child->getStoredValue();
611 state[readf] = structVal->createChild(symbol, resVal);
612 LLVM_DEBUG(llvm::dbgs() << readf.getOperationName() <<
": " << *state[readf] <<
'\n');
615 readVals.push_back(readf.getVal());
616 }
else if (
auto writef = dyn_cast<FieldWriteOp>(op)) {
617 auto structVal = state.at(translate(writef.getComponent()));
618 Value writeVal = translate(writef.getVal());
619 FlatSymbolRefAttr symbol = writef.getFieldNameAttr();
620 auto valTree = tryGetValTree(writeVal);
622 auto child = structVal->getOrCreateChild(symbol);
623 if (child->getStoredValue() == writeVal) {
625 llvm::dbgs() << writef.getOperationName() <<
": recording redundant write " << writef
628 redundantWrites.push_back(writef);
630 if (
auto *lastWrite = child->updateLastWrite(writef)) {
632 llvm::dbgs() << writef.getOperationName() <<
": recording overwritten write "
633 << *lastWrite <<
'\n'
635 redundantWrites.push_back(lastWrite);
637 child->setCurrentValue(writeVal, valTree);
639 llvm::dbgs() << writef.getOperationName() <<
": " << *child <<
" set to " << writeVal
645 else if (
auto newArray = dyn_cast<CreateArrayOp>(op)) {
646 auto arrayVal = ReferenceNode::create(newArray, newArray);
647 state[newArray] = arrayVal;
652 for (
auto elem : newArray.getElements()) {
653 Value elemVal = translate(elem);
654 auto valTree = tryGetValTree(elemVal);
655 auto elemChild = arrayVal->createChild(idx, elemVal, valTree);
657 llvm::dbgs() << newArray.getOperationName() <<
": element " << idx <<
" initialized to "
658 << *elemChild <<
'\n'
663 readVals.push_back(newArray);
664 }
else if (
auto readarr = dyn_cast<ReadArrayOp>(op)) {
665 doArrayReadLike(readarr);
666 }
else if (
auto writearr = dyn_cast<WriteArrayOp>(op)) {
667 doArrayWriteLike(writearr);
668 }
else if (
auto extractarr = dyn_cast<ExtractArrayOp>(op)) {
670 doArrayReadLike(extractarr);
671 }
else if (
auto insertarr = dyn_cast<InsertArrayOp>(op)) {
673 doArrayWriteLike(insertarr);
681 return std::make_unique<RedundantReadAndWriteEliminationPass>();
void print(llvm::raw_ostream &os) const
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
Restricts a template parameter to Op classes that implement the given OpInterface.
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::unique_ptr< mlir::Pass > createRedundantReadAndWriteEliminationPass()
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs)
static ReferenceID getEmptyKey()
static ReferenceID getTombstoneKey()
static unsigned getHashValue(const ReferenceID &r)