15#include <mlir/IR/BuiltinOps.h>
17#include <llvm/ADT/DepthFirstIterator.h>
18#include <llvm/ADT/SmallSet.h>
19#include <llvm/Support/GraphWriter.h>
30 assert(!node->parent &&
"def cannot be in more than one symbol table");
32 children.insert(node);
41void assertProperBuild(SymbolOpInterface root,
const SymbolDefTree *tree) {
43 llvm::SmallSet<SymbolOpInterface, 16> fromGraph;
45 if (SymbolOpInterface s = r->getOp()) {
50 root.walk([&fromGraph](SymbolOpInterface s) { assert(fromGraph.contains(s)); });
56 assert(rootSymbol->hasTrait<OpTrait::SymbolTable>());
57 buildTree(rootSymbol,
nullptr);
58 assertProperBuild(rootSymbol,
this);
61void SymbolDefTree::buildTree(SymbolOpInterface symbolOp,
SymbolDefTreeNode *parentNode) {
63 parentNode = getOrAddNode(symbolOp, parentNode);
65 if (symbolOp->hasTrait<OpTrait::SymbolTable>()) {
66 for (Operation &op : symbolOp->getRegion(0).front()) {
67 if (SymbolOpInterface childSym = llvm::dyn_cast<SymbolOpInterface>(&op)) {
68 buildTree(childSym, parentNode);
75SymbolDefTree::getOrAddNode(SymbolOpInterface symbolDef,
SymbolDefTreeNode *parentNode) {
76 std::unique_ptr<SymbolDefTreeNode> &node = nodes[symbolDef];
78 node.reset(
new SymbolDefTreeNode(symbolDef));
81 parentNode->addChild(node.get());
83 root.addChild(node.get());
90 const auto *it = nodes.find(symbolDef);
91 return it == nodes.
end() ? nullptr : it->second.get();
101 os <<
'\'' << symbolDef->getName() <<
"' ";
103 os <<
"named " <<
name <<
'\n';
105 os <<
"without a name\n";
112 os <<
"// - Node : [" << node <<
"] ";
115 os <<
"// --- Children : [";
116 llvm::interleaveComma(node->children, os);
124 os <<
"// ---- SymbolDefTree ----\n";
128 os <<
"// -----------------------\n";
133 llvm::WriteGraph(
this,
"SymbolDefTree",
false, title, filename);
child_iterator end() const
std::string toString() const
Print the node in a human readable format.
void print(llvm::raw_ostream &os) const
Builds a tree structure representing the symbol table structure.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
SymbolDefTree(mlir::SymbolOpInterface root)
void dumpToDotFile(std::string filename="") const
Dump the tree to file in dot graph format.
void print(llvm::raw_ostream &os) const
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
static std::string getGraphName(GraphType)