LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKRedundantReadAndWriteEliminationPass.cpp
Go to the documentation of this file.
1//===-- LLZKRedundantReadAndWriteEliminationPass.cpp ------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
19#include "llzk/Util/Concepts.h"
21
22#include <mlir/IR/BuiltinOps.h>
23
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
28
29#include <deque>
30#include <memory>
31
32// Include the generated base pass class definitions.
33namespace llzk {
34#define GEN_PASS_DEF_REDUNDANTREADANDWRITEELIMINATIONPASS
36} // namespace llzk
37
38using namespace mlir;
39using namespace llzk;
40using namespace llzk::array;
41using namespace llzk::felt;
42using namespace llzk::function;
43using namespace llzk::component;
44
45#define DEBUG_TYPE "llzk-redundant-read-write-pass"
46
47namespace {
48
51class ReferenceID {
52public:
53 explicit ReferenceID(Value v) {
54 // reserved special pointer values for DenseMapInfo
55 if (v.getImpl() == reinterpret_cast<mlir::detail::ValueImpl *>(1) ||
56 v.getImpl() == reinterpret_cast<mlir::detail::ValueImpl *>(2)) {
57 identifier = v;
58 } else if (auto constVal = dyn_cast_if_present<FeltConstantOp>(v.getDefiningOp())) {
59 identifier = constVal.getValue();
60 } else if (auto constIdxVal = dyn_cast_if_present<arith::ConstantIndexOp>(v.getDefiningOp())) {
61 identifier = llvm::cast<IntegerAttr>(constIdxVal.getValue()).getValue();
62 } else {
63 identifier = v;
64 }
65 }
66 explicit ReferenceID(FlatSymbolRefAttr s) : identifier(s) {}
67 explicit ReferenceID(const APInt &i) : identifier(i) {}
68 explicit ReferenceID(unsigned i) : identifier(APInt(64, i)) {}
69
70 bool isValue() const { return std::holds_alternative<Value>(identifier); }
71 bool isSymbol() const { return std::holds_alternative<FlatSymbolRefAttr>(identifier); }
72 bool isConst() const { return std::holds_alternative<APInt>(identifier); }
73
74 Value getValue() const {
75 ensure(isValue(), "does not hold Value");
76 return std::get<Value>(identifier);
77 }
78
79 FlatSymbolRefAttr getSymbol() const {
80 ensure(isSymbol(), "does not hold symbol");
81 return std::get<FlatSymbolRefAttr>(identifier);
82 }
83
84 APInt getConst() const {
85 ensure(isConst(), "does not hold const");
86 return std::get<APInt>(identifier);
87 }
88
89 void print(raw_ostream &os) const {
90 if (auto v = std::get_if<Value>(&identifier)) {
91 if (auto opres = dyn_cast<OpResult>(*v)) {
92 os << '%' << opres.getResultNumber();
93 } else {
94 os << *v;
95 }
96 } else if (auto s = std::get_if<FlatSymbolRefAttr>(&identifier)) {
97 os << *s;
98 } else {
99 os << std::get<APInt>(identifier);
100 }
101 }
102
103 friend bool operator==(const ReferenceID &lhs, const ReferenceID &rhs) {
104 return lhs.identifier == rhs.identifier;
105 }
106
107 friend raw_ostream &operator<<(raw_ostream &os, const ReferenceID &id) {
108 id.print(os);
109 return os;
110 }
111
112private:
117 std::variant<FlatSymbolRefAttr, APInt, Value> identifier;
118};
119
120} // namespace
121
namespace llvm {

/// DenseMapInfo specialization so ReferenceID can be used as a DenseMap key.
template <> struct DenseMapInfo<ReferenceID> {
  /// Empty key: wraps the reserved ValueImpl* sentinel 0x1 (the ReferenceID
  /// constructor stores these sentinels untouched).
  static ReferenceID getEmptyKey() {
    return ReferenceID(mlir::Value(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
  }
  /// Tombstone key: wraps the reserved ValueImpl* sentinel 0x2.
  static inline ReferenceID getTombstoneKey() {
    return ReferenceID(mlir::Value(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
  }
  /// Hash by whichever alternative the ID holds (Value, symbol, or constant).
  static unsigned getHashValue(const ReferenceID &r) {
    if (r.isValue()) {
      return hash_value(r.getValue());
    } else if (r.isSymbol()) {
      return hash_value(r.getSymbol());
    }
    return hash_value(r.getConst());
  }
  static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs) { return lhs == rhs; }
};

} // namespace llvm
144
145namespace {
146
/// A node in a tree of references rooted at an SSA value (a struct or array).
/// Each node records: the identifier by which it is reached from its parent
/// (field symbol, constant index, or SSA index value), the value currently
/// known to be stored at that location (if any), the last write operation to
/// that location, and the child locations reachable from it.
class ReferenceNode {
public:
  /// Creates a new, parentless node identified by `id` holding `v`.
  template <typename IdType> static std::shared_ptr<ReferenceNode> create(IdType id, Value v) {
    ReferenceNode n(id, v);
    // Need the move constructor version since constructor is private
    return std::make_shared<ReferenceNode>(std::move(n));
  }

  /// Deep-copies this node; when `withChildren` is set the entire subtree is
  /// cloned recursively, otherwise the copy has no children.
  std::shared_ptr<ReferenceNode> clone(bool withChildren = true) const {
    ReferenceNode copy(identifier, storedValue);
    // Carry the last-write op over to the copy.
    copy.updateLastWrite(lastWrite);
    if (withChildren) {
      for (const auto &[id, child] : children) {
        copy.children[id] = child->clone(withChildren);
      }
    }
    return std::make_shared<ReferenceNode>(std::move(copy));
  }

  /// Creates a child reachable via `id` holding `storedVal`; if `valTree` is
  /// given, the child adopts that tree's children (see setCurrentValue).
  /// Replaces any existing child with the same id.
  template <typename IdType>
  std::shared_ptr<ReferenceNode>
  createChild(IdType id, Value storedVal, std::shared_ptr<ReferenceNode> valTree = nullptr) {
    std::shared_ptr<ReferenceNode> child = create(id, storedVal);
    child->setCurrentValue(storedVal, valTree);
    children[child->identifier] = child;
    return child;
  }

  /// Returns the child reachable via `id`, or nullptr if none exists.
  template <typename IdType> std::shared_ptr<ReferenceNode> getChild(IdType id) const {
    auto it = children.find(ReferenceID(id));
    if (it != children.end()) {
      return it->second;
    }
    return nullptr;
  }

  /// Returns the child reachable via `id`, creating it (holding `storedVal`,
  /// possibly null) if it does not exist yet.
  template <typename IdType>
  std::shared_ptr<ReferenceNode> getOrCreateChild(IdType id, Value storedVal = nullptr) {
    auto it = children.find(ReferenceID(id));
    if (it != children.end()) {
      return it->second;
    }
    return createChild(id, storedVal);
  }

  /// Records `writeOp` as the latest write to this location and returns the
  /// previously recorded write (nullptr if none) — the caller uses a non-null
  /// return to detect an overwritten (dead) write.
  Operation *updateLastWrite(Operation *writeOp) {
    Operation *old = lastWrite;
    lastWrite = writeOp;
    return old;
  }

  /// Sets the value known to be stored here. If `valTree` is provided, this
  /// location now aliases that value's tree, so its children are adopted.
  void setCurrentValue(Value v, std::shared_ptr<ReferenceNode> valTree = nullptr) {
    storedValue = v;
    if (valTree != nullptr) {
      // Overwrite our current set of children with new children, since we overwrote
      // the stored value.
      children = valTree->children;
    }
  }

  /// Drops all knowledge about child locations (used when a variable index
  /// write may alias any element).
  void invalidateChildren() { children.clear(); }

  bool isLeaf() const { return children.empty(); }

  Value getStoredValue() const { return storedValue; }

  bool hasStoredValue() const { return storedValue != nullptr; }

  /// Prints `[id => storedValue]{ children... }` with the given indentation.
  void print(raw_ostream &os, int indent = 0) const {
    os.indent(indent) << '[' << identifier;
    if (storedValue != nullptr) {
      os << " => " << storedValue;
    }
    os << ']';
    if (!children.empty()) {
      os << "{\n";
      for (auto &[_, child] : children) {
        child->print(os, indent + 4);
        os << '\n';
      }
      os.indent(indent) << '}';
    }
  }

  [[maybe_unused]]
  friend raw_ostream &operator<<(raw_ostream &os, const ReferenceNode &r) {
    r.print(os);
    return os;
  }

  /// True iff the two nodes agree on identifier, stored value, and last write
  /// (children are not compared).
  friend bool
  topLevelEq(const std::shared_ptr<ReferenceNode> &lhs, const std::shared_ptr<ReferenceNode> &rhs) {
    return lhs->identifier == rhs->identifier && lhs->storedValue == rhs->storedValue &&
           lhs->lastWrite == rhs->lastWrite;
  }

  /// Returns the largest tree on which both inputs agree, or nullptr when the
  /// two roots already disagree (callers must handle the null case).
  friend std::shared_ptr<ReferenceNode> greatestCommonSubtree(
      const std::shared_ptr<ReferenceNode> &lhs, const std::shared_ptr<ReferenceNode> &rhs
  ) {
    if (!topLevelEq(lhs, rhs)) {
      return nullptr;
    }
    auto res = lhs->clone(false); // childless clone
    // Find common children and recurse
    for (auto &[id, lhsChild] : lhs->children) {
      if (auto it = rhs->children.find(id); it != rhs->children.end()) {
        auto &rhsChild = it->second;
        if (auto gcs = greatestCommonSubtree(lhsChild, rhsChild)) {
          res->children[id] = gcs;
        }
      }
    }
    return res;
  }

private:
  // How this node is reached from its parent.
  ReferenceID identifier;
  // Value currently known to be stored here; may be null.
  mlir::Value storedValue;
  // Most recent write op to this location; null if never written.
  Operation *lastWrite;
  DenseMap<ReferenceID, std::shared_ptr<ReferenceNode>> children;

  template <typename IdType>
  ReferenceNode(IdType id, Value initialVal)
      : identifier(id), storedValue(initialVal), lastWrite(nullptr), children() {}
};
300
301using ValueMap = DenseMap<mlir::Value, std::shared_ptr<ReferenceNode>>;
302
303ValueMap intersect(const ValueMap &lhs, const ValueMap &rhs) {
304 ValueMap res;
305 for (auto &[id, lhsValTree] : lhs) {
306 if (auto it = rhs.find(id); it != rhs.end()) {
307 auto &rhsValTree = it->second;
308 res[id] = greatestCommonSubtree(lhsValTree, rhsValTree);
309 }
310 }
311 return res;
312}
313
316ValueMap cloneValueMap(const ValueMap &orig) {
317 ValueMap res;
318 for (auto &[id, tree] : orig) {
319 res[id] = tree->clone();
320 }
321 return res;
322}
323
324class RedundantReadAndWriteEliminationPass
326 RedundantReadAndWriteEliminationPass> {
332 void runOnOperation() override {
333 getOperation().walk([&](FuncDefOp fn) { runOnFunc(fn); });
334 }
335
338 void runOnFunc(FuncDefOp fn) {
339 // Nothing to do for body-less functions.
340 if (fn.getCallableRegion() == nullptr) {
341 return;
342 }
343
344 LLVM_DEBUG(llvm::dbgs() << "Running on " << fn.getName() << '\n');
345
346 // Maps redundant value -> necessary value.
347 DenseMap<Value, Value> replacementMap;
348 // All values created by a new_* operation or from a read*/extract* operation.
349 SmallVector<Value> readVals;
350 // All writes that are either (1) overwritten by subsequent writes or (2)
351 // write a value that is already written.
352 SmallVector<Operation *> redundantWrites;
353
354 ValueMap initState;
355 // Initialize the state to the function arguments.
356 for (auto arg : fn.getArguments()) {
357 initState[arg] = ReferenceNode::create(arg, arg);
358 }
359 // Functions only have a single region
360 (void)runOnRegion(
361 *fn.getCallableRegion(), std::move(initState), replacementMap, readVals, redundantWrites
362 );
363
364 // Now that we have accumulated all necessary state, we perform the optimizations:
365 // - Replace all redundant values.
366 for (auto &[orig, replace] : replacementMap) {
367 LLVM_DEBUG(llvm::dbgs() << "replacing " << orig << " with " << orig << '\n');
368 orig.replaceAllUsesWith(replace);
369 // We save the deletion to the readVals loop to prevent double-free.
370 }
371 // -Remove redundant writes now that it is safe to do so.
372 for (auto *writeOp : redundantWrites) {
373 LLVM_DEBUG(llvm::dbgs() << "erase write: " << *writeOp << '\n');
374 writeOp->erase();
375 }
376 // - Now we do a pass over read values to see if any are now unused.
377 // We do this in reverse order to free up early reads if their users would
378 // be removed.
379 for (auto it = readVals.rbegin(); it != readVals.rend(); it++) {
380 Value readVal = *it;
381 if (readVal.use_empty()) {
382 LLVM_DEBUG(llvm::dbgs() << "erase read: " << readVal << '\n');
383 readVal.getDefiningOp()->erase();
384 }
385 }
386 }
387
388 ValueMap runOnRegion(
389 Region &r, ValueMap &&initState, DenseMap<Value, Value> &replacementMap,
390 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
391 ) {
392 // maps block -> state at the end of the block
393 DenseMap<Block *, ValueMap> endStates;
394 // The first block has no predecessors, so nullptr contains the init state
395 endStates[nullptr] = initState;
396 auto getBlockState = [&endStates](Block *blockPtr) {
397 auto it = endStates.find(blockPtr);
398 ensure(it != endStates.end(), "unknown end state means we have an unsupported backedge");
399 return cloneValueMap(it->second);
400 };
401 std::deque<Block *> frontier;
402 frontier.push_back(&r.front());
403 DenseSet<Block *> visited;
404
405 SmallVector<std::reference_wrapper<const ValueMap>> terminalStates;
406
407 while (!frontier.empty()) {
408 Block *currentBlock = frontier.front();
409 frontier.pop_front();
410 visited.insert(currentBlock);
411
412 // get predecessors
413 ValueMap currentState;
414 auto it = currentBlock->pred_begin();
415 auto itEnd = currentBlock->pred_end();
416 if (it == itEnd) {
417 // get the state for the entry block.
418 currentState = getBlockState(nullptr);
419 } else {
420 currentState = getBlockState(*it);
421 // If we have multiple predecessors, we take a pessimistic view and
422 // set the state as only the intersection of all predecessor states
423 // (e.g., only the common state from an if branch).
424 for (it++; it != itEnd; it++) {
425 currentState = intersect(currentState, getBlockState(*it));
426 }
427 }
428
429 // Run this block, consuming currentState and producing the endState
430 auto endState = runOnBlock(
431 *currentBlock, std::move(currentState), replacementMap, readVals, redundantWrites
432 );
433
434 // Update the end states.
435 // Since we only support the scf dialect, we should never have any
436 // backedges, so we should never already have state for this block.
437 ensure(endStates.find(currentBlock) == endStates.end(), "backedge");
438 endStates[currentBlock] = std::move(endState);
439
440 // add successors to frontier
441 if (currentBlock->hasNoSuccessors()) {
442 terminalStates.push_back(endStates[currentBlock]);
443 } else {
444 for (Block *succ : currentBlock->getSuccessors()) {
445 if (visited.find(succ) == visited.end()) {
446 frontier.push_back(succ);
447 }
448 }
449 }
450 }
451
452 // The final state is the intersection of all possible terminal states.
453 ensure(!terminalStates.empty(), "computed no states");
454 auto finalState = terminalStates.front().get();
455 for (auto it = terminalStates.begin() + 1; it != terminalStates.end(); it++) {
456 finalState = intersect(finalState, it->get());
457 }
458 return finalState;
459 }
460
461 ValueMap runOnBlock(
462 Block &b, ValueMap &&state, DenseMap<Value, Value> &replacementMap,
463 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
464 ) {
465 for (Operation &op : b) {
466 runOperation(&op, state, replacementMap, readVals, redundantWrites);
467 // Some operations have regions (e.g., scf.if). These regions must be
468 // traversed and the resulting state(s) are intersected for the final
469 // state of this operation.
470 if (!op.getRegions().empty()) {
471 SmallVector<ValueMap> regionStates;
472 for (Region &region : op.getRegions()) {
473 auto regionState =
474 runOnRegion(region, cloneValueMap(state), replacementMap, readVals, redundantWrites);
475 regionStates.push_back(regionState);
476 }
477
478 ValueMap finalState = regionStates.front();
479 for (auto it = regionStates.begin() + 1; it != regionStates.end(); it++) {
480 finalState = intersect(finalState, *it);
481 }
482 state = std::move(finalState);
483 }
484 }
485 return std::move(state);
486 }
487
495 void runOperation(
496 Operation *op, ValueMap &state, DenseMap<Value, Value> &replacementMap,
497 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
498 ) {
499 // Uses the replacement map to look up values to simplify later replacement.
500 // This avoids having a daisy chain of "replace B with A", "replace C with B",
501 // etc.
502 auto translate = [&replacementMap](Value v) {
503 if (auto it = replacementMap.find(v); it != replacementMap.end()) {
504 return it->second;
506 return v;
507 };
508
509 // Lookup the value tree in the current state or return nullptr.
510 auto tryGetValTree = [&state](Value v) -> std::shared_ptr<ReferenceNode> {
511 if (auto it = state.find(v); it != state.end()) {
512 return it->second;
514 return nullptr;
515 };
516
517 // Read a value from an array. This works on both readarr operations (which
518 // return a scalar value) and extractarr operations (which return a subarray).
519 auto doArrayReadLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass readarr) {
520 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(readarr.getArrRef()));
521
522 for (Value origIdx : readarr.getIndices()) {
523 Value idxVal = translate(origIdx);
524 currValTree = currValTree->getOrCreateChild(idxVal);
526
527 Value resVal = readarr.getResult();
528 if (!currValTree->hasStoredValue()) {
529 currValTree->setCurrentValue(resVal);
530 }
532 if (currValTree->getStoredValue() != resVal) {
533 LLVM_DEBUG(
534 llvm::dbgs() << readarr.getOperationName() << ": replace " << resVal << " with "
535 << currValTree->getStoredValue() << '\n'
536 );
537 replacementMap[resVal] = currValTree->getStoredValue();
538 } else {
539 state[resVal] = currValTree;
540 LLVM_DEBUG(
541 llvm::dbgs() << readarr.getOperationName() << ": " << resVal << " => " << *currValTree
542 << '\n'
543 );
544 }
545
546 readVals.push_back(resVal);
547 };
548
549 // Write a scalar value (for writearr) or a subarray value (for insertarr)
550 // to an array. The unique part of this operation relative to others is that
551 // we may receive a variable index (i.e., not a constant). In this case, we
552 // invalidate ajoining parts of the subtree, since it is possible that
553 // the variable index aliases one of the other elements and may or may not
554 // override that value.
555 auto doArrayWriteLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass writearr) {
556 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(writearr.getArrRef()));
557 Value newVal = translate(writearr.getRvalue());
558 std::shared_ptr<ReferenceNode> valTree = tryGetValTree(newVal);
559
560 for (Value origIdx : writearr.getIndices()) {
561 Value idxVal = translate(origIdx);
562 // This write will invalidate all children, since it may reference
563 // any number of them.
564 if (ReferenceID(idxVal).isValue()) {
565 LLVM_DEBUG(llvm::dbgs() << writearr.getOperationName() << ": invalidate alias\n");
566 currValTree->invalidateChildren();
567 }
568 currValTree = currValTree->getOrCreateChild(idxVal);
569 }
570
571 if (currValTree->getStoredValue() == newVal) {
572 LLVM_DEBUG(
573 llvm::dbgs() << writearr.getOperationName() << ": subsequent " << writearr
574 << " is redundant\n"
575 );
576 redundantWrites.push_back(writearr);
577 } else {
578 if (Operation *lastWrite = currValTree->updateLastWrite(writearr)) {
579 LLVM_DEBUG(
580 llvm::dbgs() << writearr.getOperationName() << "writearr: replacing " << lastWrite
581 << " with prior write " << *lastWrite << '\n'
582 );
583 redundantWrites.push_back(lastWrite);
584 }
585 currValTree->setCurrentValue(newVal, valTree);
586 }
587 };
588
589 // struct ops
590 if (auto newStruct = dyn_cast<CreateStructOp>(op)) {
591 // For new values, the "stored value" of the reference is the creation site.
592 auto structVal = ReferenceNode::create(newStruct, newStruct);
593 state[newStruct] = structVal;
594 LLVM_DEBUG(llvm::dbgs() << newStruct.getOperationName() << ": " << *state[newStruct] << '\n');
595 // adding this to readVals
596 readVals.push_back(newStruct);
597 } else if (auto readf = dyn_cast<FieldReadOp>(op)) {
598 auto structVal = state.at(translate(readf.getComponent()));
599 FlatSymbolRefAttr symbol = readf.getFieldNameAttr();
600 Value resVal = translate(readf.getVal());
601 // Check if such a child already exists.
602 if (auto child = structVal->getChild(symbol)) {
603 LLVM_DEBUG(
604 llvm::dbgs() << readf.getOperationName() << ": adding replacement map entry { "
605 << resVal << " => " << child->getStoredValue() << " }\n"
606 );
607 replacementMap[resVal] = child->getStoredValue();
608 } else {
609 // If we have no previous store, we create a new symbolic value for
610 // this location.
611 state[readf] = structVal->createChild(symbol, resVal);
612 LLVM_DEBUG(llvm::dbgs() << readf.getOperationName() << ": " << *state[readf] << '\n');
613 }
614 // specifically add the untranslated value back for removal checks
615 readVals.push_back(readf.getVal());
616 } else if (auto writef = dyn_cast<FieldWriteOp>(op)) {
617 auto structVal = state.at(translate(writef.getComponent()));
618 Value writeVal = translate(writef.getVal());
619 FlatSymbolRefAttr symbol = writef.getFieldNameAttr();
620 auto valTree = tryGetValTree(writeVal);
621
622 auto child = structVal->getOrCreateChild(symbol);
623 if (child->getStoredValue() == writeVal) {
624 LLVM_DEBUG(
625 llvm::dbgs() << writef.getOperationName() << ": recording redundant write " << writef
626 << '\n'
627 );
628 redundantWrites.push_back(writef);
629 } else {
630 if (auto *lastWrite = child->updateLastWrite(writef)) {
631 LLVM_DEBUG(
632 llvm::dbgs() << writef.getOperationName() << ": recording overwritten write "
633 << *lastWrite << '\n'
634 );
635 redundantWrites.push_back(lastWrite);
636 }
637 child->setCurrentValue(writeVal, valTree);
638 LLVM_DEBUG(
639 llvm::dbgs() << writef.getOperationName() << ": " << *child << " set to " << writeVal
640 << '\n'
641 );
642 }
643 }
644 // array ops
645 else if (auto newArray = dyn_cast<CreateArrayOp>(op)) {
646 auto arrayVal = ReferenceNode::create(newArray, newArray);
647 state[newArray] = arrayVal;
648
649 // If we're given a constructor, we can instantiate elements using
650 // constant indices.
651 unsigned idx = 0;
652 for (auto elem : newArray.getElements()) {
653 Value elemVal = translate(elem);
654 auto valTree = tryGetValTree(elemVal);
655 auto elemChild = arrayVal->createChild(idx, elemVal, valTree);
656 LLVM_DEBUG(
657 llvm::dbgs() << newArray.getOperationName() << ": element " << idx << " initialized to "
658 << *elemChild << '\n'
659 );
660 idx++;
661 }
662
663 readVals.push_back(newArray);
664 } else if (auto readarr = dyn_cast<ReadArrayOp>(op)) {
665 doArrayReadLike(readarr);
666 } else if (auto writearr = dyn_cast<WriteArrayOp>(op)) {
667 doArrayWriteLike(writearr);
668 } else if (auto extractarr = dyn_cast<ExtractArrayOp>(op)) {
669 // Logic is essentially the same as readarr
670 doArrayReadLike(extractarr);
671 } else if (auto insertarr = dyn_cast<InsertArrayOp>(op)) {
672 // Logic is essentially the same as writearr
673 doArrayWriteLike(insertarr);
674 }
675 }
676};
677
678} // namespace
679
681 return std::make_unique<RedundantReadAndWriteEliminationPass>();
682};
void print(llvm::raw_ostream &os) const
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
Definition Ops.h.inc:746
Restricts a template parameter to Op classes that implement the given OpInterface.
Definition Concepts.h:20
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::unique_ptr< mlir::Pass > createRedundantReadAndWriteEliminationPass()
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs)