18#include <mlir/IR/BuiltinOps.h>
20#include <llvm/ADT/SmallPtrSet.h>
21#include <llvm/Support/GraphWriter.h>
32 if (this->successors.insert(node)) {
33 node->predecessors.insert(
this);
38 if (this->successors.remove(node)) {
39 node->predecessors.remove(
this);
43FailureOr<SymbolLookupResultUntyped>
50 if (succeeded(res) || !reportMissing) {
54 return lookupFrom->emitError().append(
55 "Could not find symbol referenced in UseGraph: ",
getSymbolPath()
66R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
74 auto diag = defOp.emitError(
"in SymbolUseGraph, failed to build symbol path");
75 diag.attachNote(defOp.getLoc()).append(
"for this SymbolOp");
79 return callback(foundRoot, path.value());
85 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
86 buildGraph(rootSymbolOp);
90SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(
const SymbolTable::SymbolUse &u) {
92 return getPathAndCall<SymbolUseGraphNode *>(userSymbol, [
this](ModuleOp r, SymbolRefAttr p) {
93 return this->getOrAddNode(r, p,
nullptr);
97void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
98 auto walkFn = [
this](Operation *op, bool) {
99 assert(op->hasTrait<OpTrait::SymbolTable>());
101 if (failed(opRootModule)) {
105 SymbolTableCollection tables;
108 for (SymbolTable::SymbolUse u : usesOpt.value()) {
109 bool isStructParam =
false;
110 SymbolRefAttr symRef = u.getSymbolRef();
114 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
115 Operation *user = u.getUser();
116 if (
auto fref = llvm::dyn_cast<component::FieldRefOpInterface>(user);
117 fref && fref.getFieldNameAttr() == flatSymRef) {
120 StringAttr localName = flatSymRef.getAttr();
121 isStructParam = userStruct.hasParamNamed(localName);
122 if (isStructParam || tables.getSymbolTable(userStruct).lookup(localName)) {
126 assert(succeeded(parentPath));
131 auto node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
132 node->isStructConstParam = isStructParam;
136 SymbolTable::walkSymbolTables(symbolOp.getOperation(),
true, walkFn);
139 for (SymbolUseGraphNode *n :
nodesIter()) {
140 if (!n->hasSuccessor()) {
141 n->addSuccessor(&tail);
149 NodeMapKeyT key = std::make_pair(pathRoot, path);
150 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
152 nodeRef.reset(
new SymbolUseGraphNode(pathRoot, path));
155 if (predecessorNode) {
156 predecessorNode->addSuccessor(nodeRef.get());
158 root.addSuccessor(nodeRef.get());
160 }
else if (predecessorNode) {
163 SymbolUseGraphNode *node = nodeRef.get();
164 predecessorNode->addSuccessor(node);
165 if (node != predecessorNode) {
166 root.removeSuccessor(node);
169 return nodeRef.get();
173 NodeMapKeyT key = std::make_pair(pathRoot, path);
174 const auto *it = nodes.find(key);
175 return it == nodes.end() ? nullptr : it->second.get();
179 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [
this](ModuleOp r, SymbolRefAttr p) {
192inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
195 if (succeeded(unambiguousRoot)) {
196 os << unambiguousRoot.value() <<
'\n';
198 os <<
"<<unknown path>>\n";
201 os <<
"<<NULL MODULE>>\n";
208 os <<
'\'' << symbolPath <<
'\'';
209 if (isStructConstParam) {
210 os <<
" (struct param)";
212 os <<
" with root module ";
213 safeAppendPathRoot(os, symbolPathRoot);
220 SmallPtrSet<SymbolUseGraphNode *, 16> done;
225 if (!done.insert(node).second) {
229 os <<
"// - Node : [" << node <<
"] ";
232 os <<
"// --- Predecessors : [";
233 llvm::interleaveComma(
234 llvm::make_filter_range(
240 os <<
"// --- Successors : [";
241 llvm::interleaveComma(node->successorIter(), os);
249 os <<
"// ---- SymbolUseGraph ----\n";
253 os <<
"// ------------------------\n";
254 assert(done.size() == this->size() &&
"All nodes were not printed!");
259 llvm::WriteGraph(
this,
"SymbolUseGraph",
false, title, filename);
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
std::string toString() const
Print the node in a human readable format.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
void print(llvm::raw_ostream &os) const
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
static std::string getGraphName(GraphType)