LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
WalkPatternRewriteDriver.cpp
Go to the documentation of this file.
1//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9// See https://llvm.org/LICENSE.txt for license information.
10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11//
12//===----------------------------------------------------------------------===//
13//
14// Implements mlir::walkAndApplyPatterns.
15//
16// This file has been ported from a more recent version of LLVM with no changes.
17//
18//===----------------------------------------------------------------------===//
19
21
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/ErrorHandling.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/OperationSupport.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/IR/Verifier.h"
28#include "mlir/IR/Visitors.h"
29#include "mlir/Rewrite/PatternApplicator.h"
30
31#define DEBUG_TYPE "walk-rewriter"
32
33namespace mlir {
34
35namespace {
36struct WalkAndApplyPatternsAction final : tracing::ActionImpl<WalkAndApplyPatternsAction> {
37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
38 using ActionImpl::ActionImpl;
39 static constexpr StringLiteral tag = "walk-and-apply-patterns";
40 void print(raw_ostream &os) const override { os << tag; }
41};
42
43#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
44// Forwarding listener to guard against unsupported erasures of non-descendant
45// ops/blocks. Because we use walk-based pattern application, erasing the
46// op/block from the *next* iteration (e.g., a user of the visited op) is not
47// valid. Note that this is only used with expensive pattern API checks.
48struct ErasedOpsListener final : RewriterBase::ForwardingListener {
49 using RewriterBase::ForwardingListener::ForwardingListener;
50
51 void notifyOperationErased(Operation *op) override {
52 checkErasure(op);
53 ForwardingListener::notifyOperationErased(op);
54 }
55
56 void notifyBlockErased(Block *block) override {
57 checkErasure(block->getParentOp());
58 ForwardingListener::notifyBlockErased(block);
59 }
60
61 void checkErasure(Operation *op) const {
62 Operation *ancestorOp = op;
63 while (ancestorOp && ancestorOp != visitedOp) {
64 ancestorOp = ancestorOp->getParentOp();
65 }
66
67 if (ancestorOp != visitedOp) {
68 llvm::report_fatal_error("unsupported erasure in WalkPatternRewriter; "
69 "erasure is only supported for matched ops and their descendants");
70 }
71 }
72
73 Operation *visitedOp = nullptr;
74};
75#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
76} // namespace
77
79 Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener
80) {
81#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82 if (failed(verify(op))) {
83 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
84 }
85#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
86
87 struct PatternRewriterImpl : public PatternRewriter {
88 PatternRewriterImpl(MLIRContext *ctx) : PatternRewriter(ctx) {}
89 };
90
91 MLIRContext *ctx = op->getContext();
92 PatternRewriterImpl rewriter(ctx);
93#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
94 ErasedOpsListener erasedListener(listener);
95 rewriter.setListener(&erasedListener);
96#else
97 rewriter.setListener(listener);
98#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99
100 PatternApplicator applicator(patterns);
101 applicator.applyDefaultCostModel();
102
103 ctx->executeAction<WalkAndApplyPatternsAction>([&] {
104 for (Region &region : op->getRegions()) {
105 region.walk([&](Operation *visitedOp) {
106 LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
107 visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
108 llvm::dbgs() << "\n";);
109#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110 erasedListener.visitedOp = visitedOp;
111#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
112 if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
113 LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
114 }
115 });
116 }
117 }, {op});
118
119#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
120 if (failed(verify(op))) {
121 llvm::report_fatal_error("walk pattern rewriter result IR failed to verify");
122 }
123#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
124}
125
126} // namespace mlir
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener)
A fast walk-based pattern rewrite driver.