19#include <mlir/IR/BuiltinOps.h>
21#include <llvm/ADT/SmallPtrSet.h>
22#include <llvm/ADT/SmallSet.h>
23#include <llvm/Support/GraphWriter.h>
34 if (this->successors.insert(node)) {
35 node->predecessors.insert(
this);
40 if (this->successors.remove(node)) {
41 node->predecessors.remove(
this);
45FailureOr<SymbolLookupResultUntyped>
52 if (succeeded(res) || !reportMissing) {
56 return lookupFrom->emitError().append(
57 "Could not find symbol referenced in UseGraph: ",
getSymbolPath()
68R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
76 auto diag = defOp.emitError(
"in SymbolUseGraph, failed to build symbol path");
77 diag.attachNote(defOp.getLoc()).append(
"for this SymbolOp");
81 return callback(foundRoot, path.value());
87 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
88 buildGraph(rootSymbolOp);
92SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(
const SymbolTable::SymbolUse &u) {
94 return getPathAndCall<SymbolUseGraphNode *>(
96 [
this, &userSymbol](ModuleOp r, SymbolRefAttr p) {
97 auto n = this->getOrAddNode(r, p,
nullptr);
98 n->opsThatUseTheSymbol.insert(userSymbol);
104void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
105 auto walkFn = [
this](Operation *op, bool) {
106 assert(op->hasTrait<OpTrait::SymbolTable>());
108 if (failed(opRootModule)) {
112 SymbolTableCollection tables;
115 for (SymbolTable::SymbolUse u : usesOpt.value()) {
116 bool isStructParam =
false;
117 Operation *user = u.getUser();
118 SymbolRefAttr symRef = u.getSymbolRef();
122 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
123 if (
auto fref = llvm::dyn_cast<component::FieldRefOpInterface>(user);
124 fref && fref.getFieldNameAttr() == flatSymRef) {
127 StringAttr localName = flatSymRef.getAttr();
128 isStructParam = userStruct.hasParamNamed(localName);
129 if (isStructParam || tables.getSymbolTable(userStruct).lookup(localName)) {
133 assert(succeeded(parentPath));
138 auto node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
139 node->isStructConstParam = isStructParam;
140 node->opsThatUseTheSymbol.insert(user);
144 SymbolTable::walkSymbolTables(symbolOp.getOperation(),
true, walkFn);
147 for (SymbolUseGraphNode *n :
nodesIter()) {
148 if (!n->hasSuccessor()) {
149 n->addSuccessor(&tail);
157 NodeMapKeyT key = std::make_pair(pathRoot, path);
158 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
160 nodeRef.reset(
new SymbolUseGraphNode(pathRoot, path));
163 if (predecessorNode) {
164 predecessorNode->addSuccessor(nodeRef.get());
166 root.addSuccessor(nodeRef.get());
168 }
else if (predecessorNode) {
171 SymbolUseGraphNode *node = nodeRef.get();
172 predecessorNode->addSuccessor(node);
173 if (node != predecessorNode) {
174 root.removeSuccessor(node);
177 return nodeRef.get();
181 NodeMapKeyT key = std::make_pair(pathRoot, path);
182 const auto *it = nodes.find(key);
183 return it == nodes.end() ? nullptr : it->second.get();
187 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [
this](ModuleOp r, SymbolRefAttr p) {
202inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
205 if (succeeded(unambiguousRoot)) {
206 os << unambiguousRoot.value() <<
'\n';
208 os <<
"<<unknown path>>\n";
211 os <<
"<<NULL MODULE>>\n";
218 llvm::raw_ostream &os,
bool showLocations, std::string locationLinePrefix
220 os <<
'\'' << symbolPath <<
'\'';
221 if (isStructConstParam) {
222 os <<
" (struct param)";
224 os <<
" with root module ";
225 safeAppendPathRoot(os, symbolPathRoot);
230 llvm::SmallSet<mlir::Location, 3, LocationComparator> locations;
232 locations.insert(user->getLoc());
234 for (Location loc : locations) {
235 os << locationLinePrefix << loc <<
'\n';
244 SmallPtrSet<SymbolUseGraphNode *, 16> done;
249 if (!done.insert(node).second) {
253 os <<
"// - Node : [" << node <<
"] ";
254 node->print(os,
true,
"// --- ");
256 os <<
"// --- Predecessors : [";
257 llvm::interleaveComma(
258 llvm::make_filter_range(
264 os <<
"// --- Successors : [";
265 llvm::interleaveComma(node->successorIter(), os);
273 os <<
"// ---- SymbolUseGraph ----\n";
277 os <<
"// ------------------------\n";
278 assert(done.size() == this->size() &&
"All nodes were not printed!");
283 llvm::WriteGraph(
this,
"SymbolUseGraph",
false, title, filename);
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
void print(llvm::raw_ostream &os, bool showLocations=false, std::string locationLinePrefix="") const
std::string toString(bool showLocations=false) const
Print the node in a human readable format.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
const OpSet & getUserOps() const
The set of operations that use the symbol.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
static std::string getGraphName(GraphType)