LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolDefTree.cpp
Go to the documentation of this file.
1//===-- SymbolDefTree.cpp ---------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11#include "llzk/Util/Constants.h"
14
15#include <mlir/IR/BuiltinOps.h>
16
17#include <llvm/ADT/DepthFirstIterator.h>
18#include <llvm/ADT/SmallSet.h>
19#include <llvm/Support/GraphWriter.h>
20
21using namespace mlir;
22
23namespace llzk {
24
25//===----------------------------------------------------------------------===//
26// SymbolDefTreeNode
27//===----------------------------------------------------------------------===//
28
29void SymbolDefTreeNode::addChild(SymbolDefTreeNode *node) {
30 assert(!node->parent && "def cannot be in more than one symbol table");
31 node->parent = this;
32 children.insert(node);
33}
34
35//===----------------------------------------------------------------------===//
36// SymbolDefTree
37//===----------------------------------------------------------------------===//
38
39namespace {
40
41void assertProperBuild(SymbolOpInterface root, const SymbolDefTree *tree) {
42 // Collect all Symbols in the graph
43 llvm::SmallSet<SymbolOpInterface, 16> fromGraph;
44 for (const SymbolDefTreeNode *r : llvm::depth_first(tree)) {
45 if (SymbolOpInterface s = r->getOp()) {
46 fromGraph.insert(s);
47 }
48 }
49 // Ensure every symbol reachable from the 'root' is represented in the graph
50 root.walk([&fromGraph](SymbolOpInterface s) { assert(fromGraph.contains(s)); });
51}
52
53} // namespace
54
55SymbolDefTree::SymbolDefTree(SymbolOpInterface rootSymbol) {
56 assert(rootSymbol->hasTrait<OpTrait::SymbolTable>());
57 buildTree(rootSymbol, /*parentNode=*/nullptr);
58 assertProperBuild(rootSymbol, this);
59}
60
61void SymbolDefTree::buildTree(SymbolOpInterface symbolOp, SymbolDefTreeNode *parentNode) {
62 // Add node for the current symbol
63 parentNode = getOrAddNode(symbolOp, parentNode);
64 // If this symbol is also its own SymbolTable, recursively add child symbols
65 if (symbolOp->hasTrait<OpTrait::SymbolTable>()) {
66 for (Operation &op : symbolOp->getRegion(0).front()) {
67 if (SymbolOpInterface childSym = llvm::dyn_cast<SymbolOpInterface>(&op)) {
68 buildTree(childSym, parentNode);
69 }
70 }
71 }
72}
73
75SymbolDefTree::getOrAddNode(SymbolOpInterface symbolDef, SymbolDefTreeNode *parentNode) {
76 std::unique_ptr<SymbolDefTreeNode> &node = nodes[symbolDef];
77 if (!node) {
78 node.reset(new SymbolDefTreeNode(symbolDef));
79 // Add this node to the given parent node if given, else the root node.
80 if (parentNode) {
81 parentNode->addChild(node.get());
82 } else {
83 root.addChild(node.get());
84 }
85 }
86 return node.get();
87}
88
89const SymbolDefTreeNode *SymbolDefTree::lookupNode(SymbolOpInterface symbolDef) const {
90 const auto *it = nodes.find(symbolDef);
91 return it == nodes.end() ? nullptr : it->second.get();
92}
93
94//===----------------------------------------------------------------------===//
95// Printing
96//===----------------------------------------------------------------------===//
97
98std::string SymbolDefTreeNode::toString() const { return buildStringViaPrint(*this); }
99
100void SymbolDefTreeNode::print(llvm::raw_ostream &os) const {
101 os << '\'' << symbolDef->getName() << "' ";
102 if (StringAttr name = llzk::getSymbolName(symbolDef)) {
103 os << "named " << name << '\n';
104 } else {
105 os << "without a name\n";
106 }
107}
108
109void SymbolDefTree::print(llvm::raw_ostream &os) const {
110 std::function<void(SymbolDefTreeNode *)> printNode = [&os, &printNode](SymbolDefTreeNode *node) {
111 // Print the current node
112 os << "// - Node : [" << node << "] ";
113 node->print(os);
114 // Print list of IDs for the children
115 os << "// --- Children : [";
116 llvm::interleaveComma(node->children, os);
117 os << "]\n";
118 // Recursively print the children
119 for (SymbolDefTreeNode *c : node->children) {
120 printNode(c);
121 }
122 };
123
124 os << "// ---- SymbolDefTree ----\n";
125 for (SymbolDefTreeNode *r : root.children) {
126 printNode(r);
127 }
128 os << "// -----------------------\n";
129}
130
131void SymbolDefTree::dumpToDotFile(std::string filename) const {
133 llvm::WriteGraph(this, "SymbolDefTree", /*ShortNames*/ false, title, filename);
134}
135
136} // namespace llzk
MlirStringRef name
child_iterator end() const
std::string toString() const
Print the node in a human readable format.
void print(llvm::raw_ostream &os) const
Builds a tree structure representing the symbol table structure.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
SymbolDefTree(mlir::SymbolOpInterface root)
void dumpToDotFile(std::string filename="") const
Dump the tree to file in dot graph format.
void print(llvm::raw_ostream &os) const
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...