19#include <mlir/IR/BuiltinOps.h>
21#include <llvm/ADT/SmallPtrSet.h>
22#include <llvm/ADT/SmallSet.h>
23#include <llvm/Support/GraphWriter.h>
34 if (this->successors.insert(node)) {
35 node->predecessors.insert(
this);
40 if (this->successors.remove(node)) {
41 node->predecessors.remove(
this);
45FailureOr<SymbolLookupResultUntyped>
52 if (succeeded(res) || !reportMissing) {
56 return lookupFrom->emitError().append(
57 "Could not find symbol referenced in UseGraph: ",
getSymbolPath()
68R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
76 auto diag = defOp.emitError(
"in SymbolUseGraph, failed to build symbol path");
77 diag.attachNote(defOp.getLoc()).append(
"for this SymbolOp");
81 return callback(foundRoot, path.value());
87 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
88 buildGraph(rootSymbolOp);
92SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(
const SymbolTable::SymbolUse &u) {
94 return getPathAndCall<SymbolUseGraphNode *>(
95 userSymbol, [
this, &userSymbol](ModuleOp r, SymbolRefAttr p) {
96 auto n = this->getOrAddNode(r, p,
nullptr);
97 n->opsThatUseTheSymbol.insert(userSymbol);
103void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
104 auto walkFn = [
this](Operation *op, bool) {
105 assert(op->hasTrait<OpTrait::SymbolTable>());
107 if (failed(opRootModule)) {
111 SymbolTableCollection tables;
114 for (SymbolTable::SymbolUse u : usesOpt.value()) {
115 bool isStructParam =
false;
116 Operation *user = u.getUser();
117 SymbolRefAttr symRef = u.getSymbolRef();
121 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
122 if (
auto fref = llvm::dyn_cast<component::FieldRefOpInterface>(user);
123 fref && fref.getFieldNameAttr() == flatSymRef) {
126 StringAttr localName = flatSymRef.getAttr();
127 isStructParam = userStruct.hasParamNamed(localName);
128 if (isStructParam || tables.getSymbolTable(userStruct).lookup(localName)) {
132 assert(succeeded(parentPath));
137 auto node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
138 node->isStructConstParam = isStructParam;
139 node->opsThatUseTheSymbol.insert(user);
143 SymbolTable::walkSymbolTables(symbolOp.getOperation(),
true, walkFn);
146 for (SymbolUseGraphNode *n :
nodesIter()) {
147 if (!n->hasSuccessor()) {
148 n->addSuccessor(&tail);
156 NodeMapKeyT key = std::make_pair(pathRoot, path);
157 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
159 nodeRef.reset(
new SymbolUseGraphNode(pathRoot, path));
162 if (predecessorNode) {
163 predecessorNode->addSuccessor(nodeRef.get());
165 root.addSuccessor(nodeRef.get());
167 }
else if (predecessorNode) {
170 SymbolUseGraphNode *node = nodeRef.get();
171 predecessorNode->addSuccessor(node);
172 if (node != predecessorNode) {
173 root.removeSuccessor(node);
176 return nodeRef.get();
180 NodeMapKeyT key = std::make_pair(pathRoot, path);
181 const auto *it = nodes.find(key);
182 return it == nodes.end() ? nullptr : it->second.get();
186 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [
this](ModuleOp r, SymbolRefAttr p) {
201inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
204 if (succeeded(unambiguousRoot)) {
205 os << unambiguousRoot.value() <<
'\n';
207 os <<
"<<unknown path>>\n";
210 os <<
"<<NULL MODULE>>\n";
217 llvm::raw_ostream &os,
bool showLocations, std::string locationLinePrefix
219 os <<
'\'' << symbolPath <<
'\'';
220 if (isStructConstParam) {
221 os <<
" (struct param)";
223 os <<
" with root module ";
224 safeAppendPathRoot(os, symbolPathRoot);
229 llvm::SmallSet<mlir::Location, 3, LocationComparator> locations;
231 locations.insert(user->getLoc());
233 for (Location loc : locations) {
234 os << locationLinePrefix << loc <<
'\n';
243 SmallPtrSet<SymbolUseGraphNode *, 16> done;
248 if (!done.insert(node).second) {
252 os <<
"// - Node : [" << node <<
"] ";
253 node->print(os,
true,
"// --- ");
255 os <<
"// --- Predecessors : [";
256 llvm::interleaveComma(
257 llvm::make_filter_range(
263 os <<
"// --- Successors : [";
264 llvm::interleaveComma(node->successorIter(), os);
272 os <<
"// ---- SymbolUseGraph ----\n";
276 os <<
"// ------------------------\n";
277 assert(done.size() == this->size() &&
"All nodes were not printed!");
282 llvm::WriteGraph(
this,
"SymbolUseGraph",
false, title, filename);
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
void print(llvm::raw_ostream &os, bool showLocations=false, std::string locationLinePrefix="") const
std::string toString(bool showLocations=false) const
Print the node in a human readable format.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
const OpSet & getUserOps() const
The set of operations that use the symbol.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
static std::string getGraphName(GraphType)