15#include <mlir/IR/BuiltinOps.h>
17#include <llvm/ADT/DepthFirstIterator.h>
18#include <llvm/ADT/SmallSet.h>
19#include <llvm/Support/GraphWriter.h>
30 assert(!node->parent &&
"def cannot be in more than one symbol table");
32 children.insert(node);
/// Sanity check run after the tree has been built: every SymbolOpInterface
/// reached by walking `root` must be present in the `fromGraph` set.
/// NOTE(review): this excerpt is partial -- the loop that visits the nodes of
/// `tree` and inserts into `fromGraph` is not visible here; presumably the set
/// is filled from the ops held by the tree's nodes. Confirm against the full
/// source.
41void assertProperBuild(SymbolOpInterface root,
const SymbolDefTree *tree) {
// Symbol ops gathered for comparison against the IR walk below.
43 llvm::SmallSet<SymbolOpInterface, 16> fromGraph;
// Only nodes whose getOp() is non-null are considered (`r` is declared on a
// line that is not visible in this excerpt).
45 if (SymbolOpInterface s = r->getOp()) {
// Walk `root` and assert each symbol op found is in `fromGraph`.
// NOTE(review): the check is a plain assert, so with NDEBUG defined the
// lambda body is empty and the walk verifies nothing.
51 root.walk([&fromGraph](SymbolOpInterface s) { assert(fromGraph.contains(s)); });
58 assert(rootSymbol->hasTrait<OpTrait::SymbolTable>());
59 buildTree(rootSymbol,
nullptr);
60 assertProperBuild(rootSymbol,
this);
/// Recursively adds `symbolOp` and the symbol ops nested directly inside it
/// to the tree.
///
/// \param symbolOp   symbol operation to add at this level.
/// \param parentNode node under which `symbolOp` is registered; it is simply
///                   forwarded to getOrAddNode().
63void SymbolDefTree::buildTree(SymbolOpInterface symbolOp,
SymbolDefTreeNode *parentNode) {
// The parameter is reused: from here on `parentNode` is the node for
// `symbolOp` itself, i.e. the parent for any children found below.
65 parentNode = getOrAddNode(symbolOp, parentNode);
// Only ops carrying the SymbolTable trait are descended into.
67 if (symbolOp->hasTrait<OpTrait::SymbolTable>()) {
// Iterate the operations in the first block of region 0.
68 for (Operation &op : symbolOp->getRegion(0).front()) {
// Ops that do not implement SymbolOpInterface fail the dyn_cast and are skipped.
69 if (SymbolOpInterface childSym = llvm::dyn_cast<SymbolOpInterface>(&op)) {
// Recurse with the current node as the child's parent.
70 buildTree(childSym, parentNode);
/// Looks up the owning slot for `symbolDef` in `nodes`, creating and linking
/// a SymbolDefTreeNode for it as needed.
/// NOTE(review): the return-type line and the conditionals between the
/// statements below are missing from this excerpt; the branch structure
/// described in the comments is presumed, not shown. Confirm against the
/// full source.
77SymbolDefTree::getOrAddNode(SymbolOpInterface symbolDef,
SymbolDefTreeNode *parentNode) {
// Bind a reference to the map slot so the reset() below stores into `nodes`.
78 std::unique_ptr<SymbolDefTreeNode> &node = nodes[symbolDef];
// Allocate the node and hand ownership to the slot (presumably only when
// the slot was empty).
80 node.reset(
new SymbolDefTreeNode(symbolDef));
// Link the node beneath the supplied parent ...
83 parentNode->addChild(node.get());
// ... or beneath the tree's own `root` node (presumably when no parent was
// supplied).
85 root.addChild(node.get());
92 const auto *it = nodes.find(symbolDef);
93 return it == nodes.
end() ? nullptr : it->second.get();
103 os <<
'\'' << symbolDef->getName() <<
"' ";
105 os <<
"named " << name <<
'\n';
107 os <<
"without a name\n";
114 os <<
"// - Node : [" << node <<
"] ";
117 os <<
"// --- Children : [";
118 llvm::interleaveComma(node->children, os);
126 os <<
"// ---- SymbolDefTree ----\n";
130 os <<
"// -----------------------\n";
135 llvm::WriteGraph(
this,
"SymbolDefTree",
false, title, filename);
child_iterator end() const
std::string toString() const
Print the node in a human readable format.
void print(llvm::raw_ostream &os) const
Builds a tree structure representing the symbol table structure.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
SymbolDefTree(mlir::SymbolOpInterface root)
void dumpToDotFile(std::string filename="") const
Dump the tree to file in dot graph format.
void print(llvm::raw_ostream &os) const
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.
static std::string getGraphName(GraphType)