21#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
22#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
23#include <mlir/Analysis/DataFlowFramework.h>
24#include <mlir/Dialect/SCF/IR/SCF.h>
25#include <mlir/IR/Attributes.h>
26#include <mlir/IR/Operation.h>
27#include <mlir/IR/Region.h>
28#include <mlir/IR/SymbolTable.h>
29#include <mlir/IR/Value.h>
30#include <mlir/IR/ValueRange.h>
31#include <mlir/Interfaces/CallInterfaces.h>
32#include <mlir/Interfaces/ControlFlowInterfaces.h>
33#include <mlir/Support/LLVM.h>
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/Support/Casting.h>
42using namespace mlir::dataflow;
52 : DataFlowAnalysis(s) {
53 registerAnchorKind<CFGEdge>();
59 for (Region ®ion : top->getRegions()) {
63 for (Value argument : region.front().getArguments()) {
68 return initializeRecursively(top);
71LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
74 if (failed(visitOperation(op))) {
78 for (Region ®ion : op->getRegions()) {
79 for (Block &block : region) {
80 getOrCreate<Executable>(getProgramPointBefore(&block))->blockContentSubscribe(
this);
83 for (Operation &containedOp : block) {
84 if (failed(initializeRecursively(&containedOp))) {
95 if (!point->isBlockStart()) {
96 return visitOperation(point->getPrevOp());
98 visitBlock(point->getBlock());
102LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
108 if (op->getBlock() !=
nullptr &&
109 !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) {
114 SmallVector<AbstractSparseLattice *> resultLattices;
115 resultLattices.reserve(op->getNumResults());
116 for (Value result : op->getResults()) {
118 resultLattices.push_back(resultLattice);
122 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
123 visitRegionSuccessors(
124 getProgramPointAfter(branch), branch,
125 RegionBranchPoint::parent(), resultLattices
131 SmallVector<const AbstractSparseLattice *> operandLattices;
132 operandLattices.reserve(op->getNumOperands());
133 for (Value operand : op->getOperands()) {
135 operandLattice->useDefSubscribe(
this);
136 operandLattices.push_back(operandLattice);
143 // If the call operation is to an external function, attempt to infer the
144 // results from the call arguments.
145 auto callable = resolveCallable<FuncDefOp>(tables, call);
146 if (!getSolverConfig().isInterprocedural() ||
147 (succeeded(callable) && !callable->get().getCallableRegion())) {
148 visitExternalCallImpl(call, operandLattices, resultLattices);
152 // Otherwise, the results of a call operation are determined by the
156 SmallVector<Operation *> predecessors;
157 callable->get().walk([&predecessors](ReturnOp ret) mutable { predecessors.push_back(ret); });
159 // If not all return sites are known, then conservatively assume we can't
160 // reason about the data-flow.
161 if (predecessors.empty()) {
162 setAllToEntryStates(resultLattices);
165 for (Operation *predecessor : predecessors) {
166 for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) {
167 join(resLattice, *getLatticeElementFor(getProgramPointAfter(op), operand));
180void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
182 if (block->getNumArguments() == 0) {
187 if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive()) {
192 SmallVector<AbstractSparseLattice *> argLattices;
193 argLattices.reserve(block->getNumArguments());
194 for (BlockArgument argument : block->getArguments()) {
196 argLattices.push_back(argLattice);
201 if (block->isEntryBlock()) {
208 auto moduleOpRes = getTopRootModule(callable.getOperation());
209 ensure(succeeded(moduleOpRes), "could not get root module from callable");
210 SmallVector<Operation *> callsites;
211 moduleOpRes->walk([this, &callable, &callsites](CallOp call) mutable {
212 auto calledFnRes = resolveCallable<FuncDefOp>(tables, call);
213 if (succeeded(calledFnRes) &&
214 calledFnRes->get().getCallableRegion() == callable.getCallableRegion()) {
215 callsites.push_back(call);
218 // If not all callsites are known, conservatively mark all lattices as
219 // having reached their pessimistic fixpoints.
220 if (callsites.empty() || !getSolverConfig().isInterprocedural()) {
221 return setAllToEntryStates(argLattices);
223 for (Operation *callsite : callsites) {
224 auto call = cast<CallOpInterface>(callsite);
225 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) {
227 std::get<1>(it), *getLatticeElementFor(getProgramPointBefore(block), std::get<0>(it))
236 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
237 return visitRegionSuccessors(
238 getProgramPointBefore(block), branch, block->getParent(), argLattices
244 block->getParentOp(), RegionSuccessor(block->getParent()), argLattices, 0
249 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
250 Block *predecessor = *it;
254 auto *edgeExecutable = getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
255 edgeExecutable->blockContentSubscribe(
this);
256 if (!edgeExecutable->isLive()) {
261 if (
auto branch = dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
262 SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex());
263 for (
auto [idx, lattice] : llvm::enumerate(argLattices)) {
264 if (Value operand = operands[idx]) {
279void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
280 ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor,
281 ArrayRef<AbstractSparseLattice *> lattices
283 Operation *op = point->isBlockStart() ? point->getBlock()->getParentOp() : point->getPrevOp();
287 std::optional<OperandRange> operands;
291 operands = branch.getEntrySuccessorOperands(successor);
293 }
else if (
auto regionTerminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
294 operands = regionTerminator.getSuccessorOperands(successor);
305 if (
auto forOp = dyn_cast<scf::ForOp>(op)) {
306 inputs = forOp.getRegionIterArgs();
307 }
else if (
auto whileOp = dyn_cast<scf::WhileOp>(op)) {
308 inputs = whileOp.getRegionIterArgs();
311 if (inputs.size() != operands->size()) {
316 unsigned firstIndex = 0;
317 if (inputs.size() != lattices.size()) {
318 if (!point->isBlockStart()) {
319 if (!inputs.empty()) {
320 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
323 branch, RegionSuccessor(branch->getResults().slice(firstIndex, inputs.size())),
327 if (!inputs.empty()) {
328 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
330 Region *region = point->getBlock()->getParent();
333 RegionSuccessor(region, region->getArguments().slice(firstIndex, inputs.size())),
339 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) {
348 addDependency(state, point);
353 ArrayRef<AbstractSparseLattice *> lattices
363 propagateIfChanged(lhs, lhs->join(rhs));
This file implements sparse data-flow analysis using the data-flow analysis framework.
mlir::LogicalResult visit(mlir::ProgramPoint *point) override
Visit a program point.
mlir::LogicalResult initialize(mlir::Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual mlir::LogicalResult visitOperationImpl(mlir::Operation *op, mlir::ArrayRef< const AbstractSparseLattice * > operandLattices, mlir::ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual AbstractSparseLattice * getLatticeElement(mlir::Value value)=0
Get the lattice element of a value.
AbstractSparseForwardDataFlowAnalysis(mlir::DataFlowSolver &solver)
void setAllToEntryStates(mlir::ArrayRef< AbstractSparseLattice * > lattices)
const AbstractSparseLattice * getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(mlir::Operation *op, const mlir::RegionSuccessor &successor, mlir::ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
mlir::dataflow::AbstractSparseLattice AbstractSparseLattice