LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolDefTree.cpp
Go to the documentation of this file.
1//===-- SymbolDefTree.cpp ---------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11#include "llzk/Util/Constants.h"
14
15#include <mlir/IR/BuiltinOps.h>
16
17#include <llvm/ADT/DepthFirstIterator.h>
18#include <llvm/ADT/SmallSet.h>
19#include <llvm/Support/GraphWriter.h>
20
21using namespace mlir;
22
23namespace llzk {
24
25//===----------------------------------------------------------------------===//
26// SymbolDefTreeNode
27//===----------------------------------------------------------------------===//
28
29void SymbolDefTreeNode::addChild(SymbolDefTreeNode *node) {
30 assert(!node->parent && "def cannot be in more than one symbol table");
31 node->parent = this;
32 children.insert(node);
33}
34
35//===----------------------------------------------------------------------===//
36// SymbolDefTree
37//===----------------------------------------------------------------------===//
38
39namespace {
40
41void assertProperBuild(SymbolOpInterface root, const SymbolDefTree *tree) {
42 // Collect all Symbols in the graph
43 llvm::SmallSet<SymbolOpInterface, 16> fromGraph;
44 for (const SymbolDefTreeNode *r : llvm::depth_first(tree)) {
45 if (SymbolOpInterface s = r->getOp()) {
46 fromGraph.insert(s);
47 }
48 }
49 // Ensure every symbol reachable from the 'root' is represented in the graph
50#ifndef NDEBUG
51 root.walk([&fromGraph](SymbolOpInterface s) { assert(fromGraph.contains(s)); });
52#endif
53}
54
55} // namespace
56
/// Build the symbol-definition tree for everything nested under `rootSymbol`.
/// `rootSymbol` must itself carry the SymbolTable trait (asserted below); it
/// becomes a child of the tree's `root` node because no parent is passed.
57SymbolDefTree::SymbolDefTree(SymbolOpInterface rootSymbol) {
58 assert(rootSymbol->hasTrait<OpTrait::SymbolTable>());
59 buildTree(rootSymbol, /*parentNode=*/nullptr); // null parent => getOrAddNode attaches the node under `root`
60 assertProperBuild(rootSymbol, this); // cross-check: every symbol op under rootSymbol must appear in the tree
61}
62
63void SymbolDefTree::buildTree(SymbolOpInterface symbolOp, SymbolDefTreeNode *parentNode) {
64 // Add node for the current symbol
65 parentNode = getOrAddNode(symbolOp, parentNode);
66 // If this symbol is also its own SymbolTable, recursively add child symbols
67 if (symbolOp->hasTrait<OpTrait::SymbolTable>()) {
68 for (Operation &op : symbolOp->getRegion(0).front()) {
69 if (SymbolOpInterface childSym = llvm::dyn_cast<SymbolOpInterface>(&op)) {
70 buildTree(childSym, parentNode);
71 }
72 }
73 }
74}
75
77SymbolDefTree::getOrAddNode(SymbolOpInterface symbolDef, SymbolDefTreeNode *parentNode) {
78 std::unique_ptr<SymbolDefTreeNode> &node = nodes[symbolDef];
79 if (!node) {
80 node.reset(new SymbolDefTreeNode(symbolDef));
81 // Add this node to the given parent node if given, else the root node.
82 if (parentNode) {
83 parentNode->addChild(node.get());
84 } else {
85 root.addChild(node.get());
86 }
87 }
88 return node.get();
89}
90
91const SymbolDefTreeNode *SymbolDefTree::lookupNode(SymbolOpInterface symbolDef) const {
92 const auto *it = nodes.find(symbolDef);
93 return it == nodes.end() ? nullptr : it->second.get();
94}
95
96//===----------------------------------------------------------------------===//
97// Printing
98//===----------------------------------------------------------------------===//
99
100std::string SymbolDefTreeNode::toString() const { return buildStringViaPrint(*this); }
101
102void SymbolDefTreeNode::print(llvm::raw_ostream &os) const {
103 os << '\'' << symbolDef->getName() << "' ";
104 if (StringAttr name = llzk::getSymbolName(symbolDef)) {
105 os << "named " << name << '\n';
106 } else {
107 os << "without a name\n";
108 }
109}
110
111void SymbolDefTree::print(llvm::raw_ostream &os) const {
112 std::function<void(SymbolDefTreeNode *)> printNode = [&os, &printNode](SymbolDefTreeNode *node) {
113 // Print the current node
114 os << "// - Node : [" << node << "] ";
115 node->print(os);
116 // Print list of IDs for the children
117 os << "// --- Children : [";
118 llvm::interleaveComma(node->children, os);
119 os << "]\n";
120 // Recursively print the children
121 for (SymbolDefTreeNode *c : node->children) {
122 printNode(c);
123 }
124 };
125
126 os << "// ---- SymbolDefTree ----\n";
127 for (SymbolDefTreeNode *r : root.children) {
128 printNode(r);
129 }
130 os << "// -----------------------\n";
131}
132
/// Write this tree as a Graphviz DOT graph using llvm::WriteGraph.
/// Per LLVM's GraphWriter, an empty `filename` makes WriteGraph pick a fresh
/// temporary file; the path WriteGraph returns is discarded here.
/// NOTE(review): `title` is declared on a line elided from this listing
/// (between the signature and the WriteGraph call); confirm its contents.
133void SymbolDefTree::dumpToDotFile(std::string filename) const {
135 llvm::WriteGraph(this, "SymbolDefTree", /*ShortNames*/ false, title, filename);
136}
137
138} // namespace llzk
child_iterator end() const
std::string toString() const
Print the node in a human readable format.
void print(llvm::raw_ostream &os) const
Builds a tree representing the nesting of symbol tables and the symbol definitions they contain.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
SymbolDefTree(mlir::SymbolOpInterface root)
void dumpToDotFile(std::string filename="") const
Dump the tree to file in dot graph format.
void print(llvm::raw_ostream &os) const
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...