LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKRedundantReadAndWriteEliminationPass.cpp
Go to the documentation of this file.
1//===-- LLZKRedundantReadAndWriteEliminationPass.cpp ------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
19#include "llzk/Util/Concepts.h"
21
22#include <mlir/IR/BuiltinOps.h>
23
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
28
29#include <deque>
30#include <memory>
31
32// Include the generated base pass class definitions.
33namespace llzk {
34#define GEN_PASS_DEF_REDUNDANTREADANDWRITEELIMINATIONPASS
36} // namespace llzk
37
38using namespace mlir;
39using namespace llzk;
40using namespace llzk::array;
41using namespace llzk::felt;
42using namespace llzk::function;
43using namespace llzk::component;
44
45#define DEBUG_TYPE "llzk-redundant-read-write-pass"
46
47namespace {
48
/// Uniquely identifies one edge in a reference tree: either a struct field
/// (by symbol), a constant array index (by APInt), or an arbitrary dynamic
/// value (by mlir::Value).
class ReferenceID {
public:
  explicit ReferenceID(Value v) {
    // reserved special pointer values for DenseMapInfo
    // (see DenseMapInfo<ReferenceID>::getEmptyKey/getTombstoneKey below,
    // which construct Values from the raw pointers 1 and 2).
    if (v.getImpl() == reinterpret_cast<mlir::detail::ValueImpl *>(1) ||
        v.getImpl() == reinterpret_cast<mlir::detail::ValueImpl *>(2)) {
      identifier = v;
    } else if (auto constVal = dyn_cast_if_present<FeltConstantOp>(v.getDefiningOp())) {
      // Canonicalize felt constants to their numeric value so that equal
      // constants produced by distinct ops compare (and hash) equal.
      identifier = constVal.getValue().getValue();
    } else if (auto constIdxVal = dyn_cast_if_present<arith::ConstantIndexOp>(v.getDefiningOp())) {
      // Same canonicalization for index constants.
      identifier = llvm::cast<IntegerAttr>(constIdxVal.getValue()).getValue();
    } else {
      identifier = v;
    }
  }
  explicit ReferenceID(FlatSymbolRefAttr s) : identifier(s) {}
  explicit ReferenceID(APInt i) : identifier(i) {}
  // Unsigned indices are widened to 64-bit APInts so they unify with the
  // constant-index case above.
  explicit ReferenceID(unsigned i) : identifier(APInt(64, i)) {}

  bool isValue() const { return std::holds_alternative<Value>(identifier); }
  bool isSymbol() const { return std::holds_alternative<FlatSymbolRefAttr>(identifier); }
  bool isConst() const { return std::holds_alternative<APInt>(identifier); }

  /// Returns the held Value; aborts (via ensure) if another alternative is held.
  Value getValue() const {
    ensure(isValue(), "does not hold Value");
    return std::get<Value>(identifier);
  }

  /// Returns the held field symbol; aborts (via ensure) otherwise.
  FlatSymbolRefAttr getSymbol() const {
    ensure(isSymbol(), "does not hold symbol");
    return std::get<FlatSymbolRefAttr>(identifier);
  }

  /// Returns the held constant index; aborts (via ensure) otherwise.
  APInt getConst() const {
    ensure(isConst(), "does not hold const");
    return std::get<APInt>(identifier);
  }

  void print(raw_ostream &os) const {
    if (auto v = std::get_if<Value>(&identifier)) {
      if (auto opres = dyn_cast<OpResult>(*v)) {
        // Print a compact "%N" form for op results instead of the full op.
        os << '%' << opres.getResultNumber();
      } else {
        os << *v;
      }
    } else if (auto s = std::get_if<FlatSymbolRefAttr>(&identifier)) {
      os << *s;
    } else {
      os << std::get<APInt>(identifier);
    }
  }

  friend bool operator==(const ReferenceID &lhs, const ReferenceID &rhs) {
    return lhs.identifier == rhs.identifier;
  }

  friend raw_ostream &operator<<(raw_ostream &os, const ReferenceID &id) {
    id.print(os);
    return os;
  }

private:
  // Exactly one alternative is held at a time; constructors above pick it.
  std::variant<FlatSymbolRefAttr, APInt, Value> identifier;
};
119
120} // namespace
121
122namespace llvm {
123
125template <> struct DenseMapInfo<ReferenceID> {
126 static ReferenceID getEmptyKey() {
127 return ReferenceID(mlir::Value(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
128 }
129 static inline ReferenceID getTombstoneKey() {
130 return ReferenceID(mlir::Value(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
131 }
132 static unsigned getHashValue(const ReferenceID &r) {
133 if (r.isValue()) {
134 return hash_value(r.getValue());
135 } else if (r.isSymbol()) {
136 return hash_value(r.getSymbol());
137 }
138 return hash_value(r.getConst());
139 }
140 static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs) { return lhs == rhs; }
141};
142
143} // namespace llvm
144
145namespace {
146
/// A node in a tree of references rooted at a base value (function argument,
/// struct creation, or array creation). Each node records the last value
/// stored at that location, the operation that wrote it, and children keyed
/// by ReferenceID (field symbol or array index).
class ReferenceNode {
public:
  /// Creates a root node identified by `id` whose current stored value is `v`.
  template <typename IdType> static std::shared_ptr<ReferenceNode> create(IdType id, Value v) {
    ReferenceNode n(id, v);
    // Need the move constructor version since constructor is private
    return std::make_shared<ReferenceNode>(std::move(n));
  }

  /// Deep-copies this node (and, by default, its whole subtree). The
  /// lastWrite pointer is carried over so redundancy tracking survives clones.
  std::shared_ptr<ReferenceNode> clone(bool withChildren = true) const {
    ReferenceNode copy(identifier, storedValue);
    copy.updateLastWrite(lastWrite);
    if (withChildren) {
      for (const auto &[id, child] : children) {
        copy.children[id] = child->clone(withChildren);
      }
    }
    return std::make_shared<ReferenceNode>(std::move(copy));
  }

  /// Creates (and registers) a child keyed by `id` storing `storedVal`.
  /// If `valTree` is given, the child adopts that tree's children (the value
  /// written may itself be an aggregate with known contents).
  template <typename IdType>
  std::shared_ptr<ReferenceNode>
  createChild(IdType id, Value storedVal, std::shared_ptr<ReferenceNode> valTree = nullptr) {
    std::shared_ptr<ReferenceNode> child = create(id, storedVal);
    child->setCurrentValue(storedVal, valTree);
    children[child->identifier] = child;
    return child;
  }

  /// Returns the child keyed by `id`, or nullptr if none exists.
  template <typename IdType> std::shared_ptr<ReferenceNode> getChild(IdType id) const {
    auto it = children.find(ReferenceID(id));
    if (it != children.end()) {
      return it->second;
    }
    return nullptr;
  }

  /// Returns the child keyed by `id`, creating it (with optional initial
  /// stored value) if it does not exist yet.
  template <typename IdType>
  std::shared_ptr<ReferenceNode> getOrCreateChild(IdType id, Value storedVal = nullptr) {
    auto it = children.find(ReferenceID(id));
    if (it != children.end()) {
      return it->second;
    }
    return createChild(id, storedVal);
  }

  /// Records `writeOp` as the latest write to this location and returns the
  /// previous writer (nullptr if none) so the caller can mark it redundant.
  Operation *updateLastWrite(Operation *writeOp) {
    Operation *old = lastWrite;
    lastWrite = writeOp;
    return old;
  }

  void setCurrentValue(Value v, std::shared_ptr<ReferenceNode> valTree = nullptr) {
    storedValue = v;
    if (valTree != nullptr) {
      // Overwrite our current set of children with new children, since we overwrote
      // the stored value.
      children = valTree->children;
    }
  }

  /// Drops all known children (used when a variable index may alias any of them).
  void invalidateChildren() { children.clear(); }

  bool isLeaf() const { return children.empty(); }

  Value getStoredValue() const { return storedValue; }

  bool hasStoredValue() const { return storedValue != nullptr; }

  void print(raw_ostream &os, int indent = 0) const {
    os.indent(indent) << '[' << identifier;
    if (storedValue != nullptr) {
      os << " => " << storedValue;
    }
    os << ']';
    if (!children.empty()) {
      os << "{\n";
      for (auto &[_, child] : children) {
        child->print(os, indent + 4);
        os << '\n';
      }
      os.indent(indent) << '}';
    }
  }

  friend raw_ostream &operator<<(raw_ostream &os, const ReferenceNode &r) {
    r.print(os);
    return os;
  }

  /// Shallow equality: same identifier, stored value, and last writer.
  /// Children are deliberately ignored (see greatestCommonSubtree).
  friend bool
  topLevelEq(const std::shared_ptr<ReferenceNode> &lhs, const std::shared_ptr<ReferenceNode> &rhs) {
    return lhs->identifier == rhs->identifier && lhs->storedValue == rhs->storedValue &&
           lhs->lastWrite == rhs->lastWrite;
  }

  /// Returns the largest subtree present (and equal) in both inputs, or
  /// nullptr if even the roots disagree. Used to intersect the states of
  /// converging control-flow paths.
  friend std::shared_ptr<ReferenceNode> greatestCommonSubtree(
      const std::shared_ptr<ReferenceNode> &lhs, const std::shared_ptr<ReferenceNode> &rhs
  ) {
    if (!topLevelEq(lhs, rhs)) {
      return nullptr;
    }
    auto res = lhs->clone(false); // childless clone
    // Find common children and recurse
    for (auto &[id, lhsChild] : lhs->children) {
      if (auto it = rhs->children.find(id); it != rhs->children.end()) {
        auto &rhsChild = it->second;
        if (auto gcs = greatestCommonSubtree(lhsChild, rhsChild)) {
          res->children[id] = gcs;
        }
      }
    }
    return res;
  }

private:
  ReferenceID identifier;
  // The value most recently observed/stored at this location (may be null).
  mlir::Value storedValue;
  // The op that performed the latest write, if any; candidates for removal
  // when overwritten without an intervening read.
  Operation *lastWrite;
  DenseMap<ReferenceID, std::shared_ptr<ReferenceNode>> children;

  template <typename IdType>
  ReferenceNode(IdType id, Value initialVal)
      : identifier(id), storedValue(initialVal), lastWrite(nullptr), children() {}
};
299
300using ValueMap = DenseMap<mlir::Value, std::shared_ptr<ReferenceNode>>;
301
302ValueMap intersect(const ValueMap &lhs, const ValueMap &rhs) {
303 ValueMap res;
304 for (auto &[id, lhsValTree] : lhs) {
305 if (auto it = rhs.find(id); it != rhs.end()) {
306 auto &rhsValTree = it->second;
307 res[id] = greatestCommonSubtree(lhsValTree, rhsValTree);
308 }
309 }
310 return res;
311}
312
315ValueMap cloneValueMap(const ValueMap &orig) {
316 ValueMap res;
317 for (auto &[id, tree] : orig) {
318 res[id] = tree->clone();
319 }
320 return res;
321}
322
323class RedundantReadAndWriteEliminationPass
325 RedundantReadAndWriteEliminationPass> {
  /// Pass entry point: runs the elimination on every function definition in
  /// the module.
  void runOnOperation() override {
    getOperation().walk([&](FuncDefOp fn) { runOnFunc(fn); });
  }
334
337 void runOnFunc(FuncDefOp fn) {
338 // Nothing to do for body-less functions.
339 if (fn.getCallableRegion() == nullptr) {
340 return;
341 }
342
343 LLVM_DEBUG(llvm::dbgs() << "Running on " << fn.getName() << '\n');
344
345 // Maps redundant value -> necessary value.
346 DenseMap<Value, Value> replacementMap;
347 // All values created by a new_* operation or from a read*/extract* operation.
348 SmallVector<Value> readVals;
349 // All writes that are either (1) overwritten by subsequent writes or (2)
350 // write a value that is already written.
351 SmallVector<Operation *> redundantWrites;
352
353 ValueMap initState;
354 // Initialize the state to the function arguments.
355 for (auto arg : fn.getArguments()) {
356 initState[arg] = ReferenceNode::create(arg, arg);
357 }
358 // Functions only have a single region
359 (void)runOnRegion(
360 *fn.getCallableRegion(), std::move(initState), replacementMap, readVals, redundantWrites
361 );
362
363 // Now that we have accumulated all necessary state, we perform the optimizations:
364 // - Replace all redundant values.
365 for (auto &[orig, replace] : replacementMap) {
366 LLVM_DEBUG(llvm::dbgs() << "replacing " << orig << " with " << orig << '\n');
367 orig.replaceAllUsesWith(replace);
368 // We save the deletion to the readVals loop to prevent double-free.
369 }
370 // -Remove redundant writes now that it is safe to do so.
371 for (auto *writeOp : redundantWrites) {
372 LLVM_DEBUG(llvm::dbgs() << "erase write: " << *writeOp << '\n');
373 writeOp->erase();
374 }
375 // - Now we do a pass over read values to see if any are now unused.
376 // We do this in reverse order to free up early reads if their users would
377 // be removed.
378 for (auto it = readVals.rbegin(); it != readVals.rend(); it++) {
379 Value readVal = *it;
380 if (readVal.use_empty()) {
381 LLVM_DEBUG(llvm::dbgs() << "erase read: " << readVal << '\n');
382 readVal.getDefiningOp()->erase();
383 }
384 }
385 }
386
387 ValueMap runOnRegion(
388 Region &r, ValueMap &&initState, DenseMap<Value, Value> &replacementMap,
389 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
390 ) {
391 // maps block -> state at the end of the block
392 DenseMap<Block *, ValueMap> endStates;
393 // The first block has no predecessors, so nullptr contains the init state
394 endStates[nullptr] = initState;
395 auto getBlockState = [&endStates](Block *blockPtr) {
396 auto it = endStates.find(blockPtr);
397 ensure(it != endStates.end(), "unknown end state means we have an unsupported backedge");
398 return cloneValueMap(it->second);
399 };
400 std::deque<Block *> frontier;
401 frontier.push_back(&r.front());
402 DenseSet<Block *> visited;
403
404 SmallVector<std::reference_wrapper<const ValueMap>> terminalStates;
405
406 while (!frontier.empty()) {
407 Block *currentBlock = frontier.front();
408 frontier.pop_front();
409 visited.insert(currentBlock);
410
411 // get predecessors
412 ValueMap currentState;
413 auto it = currentBlock->pred_begin();
414 auto itEnd = currentBlock->pred_end();
415 if (it == itEnd) {
416 // get the state for the entry block.
417 currentState = getBlockState(nullptr);
418 } else {
419 currentState = getBlockState(*it);
420 // If we have multiple predecessors, we take a pessimistic view and
421 // set the state as only the intersection of all predecessor states
422 // (e.g., only the common state from an if branch).
423 for (it++; it != itEnd; it++) {
424 currentState = intersect(currentState, getBlockState(*it));
425 }
426 }
427
428 // Run this block, consuming currentState and producing the endState
429 auto endState = runOnBlock(
430 *currentBlock, std::move(currentState), replacementMap, readVals, redundantWrites
431 );
432
433 // Update the end states.
434 // Since we only support the scf dialect, we should never have any
435 // backedges, so we should never already have state for this block.
436 ensure(endStates.find(currentBlock) == endStates.end(), "backedge");
437 endStates[currentBlock] = std::move(endState);
438
439 // add successors to frontier
440 if (currentBlock->hasNoSuccessors()) {
441 terminalStates.push_back(endStates[currentBlock]);
442 } else {
443 for (Block *succ : currentBlock->getSuccessors()) {
444 if (visited.find(succ) == visited.end()) {
445 frontier.push_back(succ);
446 }
447 }
448 }
449 }
450
451 // The final state is the intersection of all possible terminal states.
452 ensure(!terminalStates.empty(), "computed no states");
453 auto finalState = terminalStates.front().get();
454 for (auto it = terminalStates.begin() + 1; it != terminalStates.end(); it++) {
455 finalState = intersect(finalState, it->get());
456 }
457 return finalState;
458 }
459
460 ValueMap runOnBlock(
461 Block &b, ValueMap &&state, DenseMap<Value, Value> &replacementMap,
462 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
463 ) {
464 for (Operation &op : b) {
465 runOperation(&op, state, replacementMap, readVals, redundantWrites);
466 // Some operations have regions (e.g., scf.if). These regions must be
467 // traversed and the resulting state(s) are intersected for the final
468 // state of this operation.
469 if (!op.getRegions().empty()) {
470 SmallVector<ValueMap> regionStates;
471 for (Region &region : op.getRegions()) {
472 auto regionState =
473 runOnRegion(region, cloneValueMap(state), replacementMap, readVals, redundantWrites);
474 regionStates.push_back(regionState);
475 }
476
477 ValueMap finalState = regionStates.front();
478 for (auto it = regionStates.begin() + 1; it != regionStates.end(); it++) {
479 finalState = intersect(finalState, *it);
480 }
481 state = std::move(finalState);
482 }
483 }
484 return std::move(state);
485 }
486
  /// Transfer function for a single operation: updates `state` (the symbolic
  /// reference trees) and records replacement candidates, read values, and
  /// redundant writes. Only struct and array create/read/write ops are
  /// modeled; all other ops leave the state untouched.
  void runOperation(
      Operation *op, ValueMap &state, DenseMap<Value, Value> &replacementMap,
      SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
  ) {
    // Uses the replacement map to look up values to simplify later replacement.
    // This avoids having a daisy chain of "replace B with A", "replace C with B",
    // etc.
    auto translate = [&replacementMap](Value v) {
      if (auto it = replacementMap.find(v); it != replacementMap.end()) {
        return it->second;
      }
      return v;
    };

    // Lookup the value tree in the current state or return nullptr.
    auto tryGetValTree = [&state](Value v) -> std::shared_ptr<ReferenceNode> {
      if (auto it = state.find(v); it != state.end()) {
        return it->second;
      }
      return nullptr;
    };

    // Read a value from an array. This works on both readarr operations (which
    // return a scalar value) and extractarr operations (which return a subarray).
    auto doArrayReadLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass readarr) {
      std::shared_ptr<ReferenceNode> currValTree = state.at(translate(readarr.getArrRef()));

      // Walk down the tree one level per index operand.
      for (Value origIdx : readarr.getIndices()) {
        Value idxVal = translate(origIdx);
        currValTree = currValTree->getOrCreateChild(idxVal);
      }

      Value resVal = readarr.getResult();
      // First read of this location: record the read's result as its value.
      if (!currValTree->hasStoredValue()) {
        currValTree->setCurrentValue(resVal);
      }

      if (currValTree->getStoredValue() != resVal) {
        // The location already has a known value; this read is redundant.
        LLVM_DEBUG(
            llvm::dbgs() << readarr.getOperationName() << ": replace " << resVal << " with "
                         << currValTree->getStoredValue() << '\n'
        );
        replacementMap[resVal] = currValTree->getStoredValue();
      } else {
        state[resVal] = currValTree;
        LLVM_DEBUG(
            llvm::dbgs() << readarr.getOperationName() << ": " << resVal << " => " << *currValTree
                         << '\n'
        );
      }

      readVals.push_back(resVal);
    };

    // Write a scalar value (for writearr) or a subarray value (for insertarr)
    // to an array. The unique part of this operation relative to others is that
    // we may receive a variable index (i.e., not a constant). In this case, we
    // invalidate ajoining parts of the subtree, since it is possible that
    // the variable index aliases one of the other elements and may or may not
    // override that value.
    auto doArrayWriteLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass writearr) {
      std::shared_ptr<ReferenceNode> currValTree = state.at(translate(writearr.getArrRef()));
      Value newVal = translate(writearr.getRvalue());
      std::shared_ptr<ReferenceNode> valTree = tryGetValTree(newVal);

      for (Value origIdx : writearr.getIndices()) {
        Value idxVal = translate(origIdx);
        // This write will invalidate all children, since it may reference
        // any number of them.
        if (ReferenceID(idxVal).isValue()) {
          LLVM_DEBUG(llvm::dbgs() << writearr.getOperationName() << ": invalidate alias\n");
          currValTree->invalidateChildren();
        }
        currValTree = currValTree->getOrCreateChild(idxVal);
      }

      if (currValTree->getStoredValue() == newVal) {
        // Writing the value that is already there: the write itself is dead.
        LLVM_DEBUG(
            llvm::dbgs() << writearr.getOperationName() << ": subsequent " << writearr
                         << " is redundant\n"
        );
        redundantWrites.push_back(writearr);
      } else {
        // A prior write to this location is overwritten without an
        // intervening observable read, so the prior write is dead.
        if (Operation *lastWrite = currValTree->updateLastWrite(writearr)) {
          LLVM_DEBUG(
              llvm::dbgs() << writearr.getOperationName() << "writearr: replacing " << lastWrite
                           << " with prior write " << *lastWrite << '\n'
          );
          redundantWrites.push_back(lastWrite);
        }
        currValTree->setCurrentValue(newVal, valTree);
      }
    };

    // struct ops
    if (auto newStruct = dyn_cast<CreateStructOp>(op)) {
      // For new values, the "stored value" of the reference is the creation site.
      auto structVal = ReferenceNode::create(newStruct, newStruct);
      state[newStruct] = structVal;
      LLVM_DEBUG(llvm::dbgs() << newStruct.getOperationName() << ": " << *state[newStruct] << '\n');
      // adding this to readVals so an unused creation can be erased later
      readVals.push_back(newStruct);
    } else if (auto readf = dyn_cast<FieldReadOp>(op)) {
      auto structVal = state.at(translate(readf.getComponent()));
      FlatSymbolRefAttr symbol = readf.getFieldNameAttr();
      Value resVal = translate(readf.getVal());
      // Check if such a child already exists.
      if (auto child = structVal->getChild(symbol)) {
        // Field already has a known value; this read is redundant.
        LLVM_DEBUG(
            llvm::dbgs() << readf.getOperationName() << ": adding replacement map entry { "
                         << resVal << " => " << child->getStoredValue() << " }\n"
        );
        replacementMap[resVal] = child->getStoredValue();
      } else {
        // If we have no previous store, we create a new symbolic value for
        // this location.
        state[readf] = structVal->createChild(symbol, resVal);
        LLVM_DEBUG(llvm::dbgs() << readf.getOperationName() << ": " << *state[readf] << '\n');
      }
      // specifically add the untranslated value back for removal checks
      readVals.push_back(readf.getVal());
    } else if (auto writef = dyn_cast<FieldWriteOp>(op)) {
      auto structVal = state.at(translate(writef.getComponent()));
      Value writeVal = translate(writef.getVal());
      FlatSymbolRefAttr symbol = writef.getFieldNameAttr();
      auto valTree = tryGetValTree(writeVal);

      auto child = structVal->getOrCreateChild(symbol);
      if (child->getStoredValue() == writeVal) {
        // Same value already stored: the write is dead.
        LLVM_DEBUG(
            llvm::dbgs() << writef.getOperationName() << ": recording redundant write " << writef
                         << '\n'
        );
        redundantWrites.push_back(writef);
      } else {
        // A previous write with no intervening read is overwritten: dead.
        if (auto *lastWrite = child->updateLastWrite(writef)) {
          LLVM_DEBUG(
              llvm::dbgs() << writef.getOperationName() << ": recording overwritten write "
                           << *lastWrite << '\n'
          );
          redundantWrites.push_back(lastWrite);
        }
        child->setCurrentValue(writeVal, valTree);
        LLVM_DEBUG(
            llvm::dbgs() << writef.getOperationName() << ": " << *child << " set to " << writeVal
                         << '\n'
        );
      }
    }
    // array ops
    else if (auto newArray = dyn_cast<CreateArrayOp>(op)) {
      auto arrayVal = ReferenceNode::create(newArray, newArray);
      state[newArray] = arrayVal;

      // If we're given a constructor, we can instantiate elements using
      // constant indices.
      unsigned idx = 0;
      for (auto elem : newArray.getElements()) {
        Value elemVal = translate(elem);
        auto valTree = tryGetValTree(elemVal);
        auto elemChild = arrayVal->createChild(idx, elemVal, valTree);
        LLVM_DEBUG(
            llvm::dbgs() << newArray.getOperationName() << ": element " << idx << " initialized to "
                         << *elemChild << '\n'
        );
        idx++;
      }

      readVals.push_back(newArray);
    } else if (auto readarr = dyn_cast<ReadArrayOp>(op)) {
      doArrayReadLike(readarr);
    } else if (auto writearr = dyn_cast<WriteArrayOp>(op)) {
      doArrayWriteLike(writearr);
    } else if (auto extractarr = dyn_cast<ExtractArrayOp>(op)) {
      // Logic is essentially the same as readarr
      doArrayReadLike(extractarr);
    } else if (auto insertarr = dyn_cast<InsertArrayOp>(op)) {
      // Logic is essentially the same as writearr
      doArrayWriteLike(insertarr);
    }
  }
675};
676
677} // namespace
678
680 return std::make_unique<RedundantReadAndWriteEliminationPass>();
681};
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
Definition Ops.h.inc:588
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ConstrainRef &rhs)
void ensure(bool condition, llvm::Twine errMsg)
Definition ErrorHelper.h:35
std::unique_ptr< mlir::Pass > createRedundantReadAndWriteEliminationPass()
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs)