LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolDefTree.h
Go to the documentation of this file.
1//===-- SymbolDefTree.h -----------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
12#include <mlir/IR/SymbolTable.h>
13
14#include <llvm/ADT/GraphTraits.h>
15#include <llvm/ADT/MapVector.h>
16#include <llvm/ADT/SetVector.h>
17#include <llvm/Support/DOTGraphTraits.h>
18
19namespace llzk {
20
21class SymbolDefTreeNode {
22 // The Symbol operation referenced by this node.
23 mlir::SymbolOpInterface symbolDef;
24
25 /* Tree structure. The SymbolDefTree owns the nodes so just pointers here. */
26 SymbolDefTreeNode *parent;
27 mlir::SetVector<SymbolDefTreeNode *> children;
28
29 SymbolDefTreeNode(mlir::SymbolOpInterface symbolDefOp) : symbolDef(symbolDefOp), parent(nullptr) {
30 assert(symbolDef && "must have a SymbolOpInterface node");
31 }
32
33 // Used only for creating the root node in the tree.
34 SymbolDefTreeNode() : symbolDef(nullptr), parent(nullptr) {}
35
37 void addChild(SymbolDefTreeNode *node);
38
39 // Provide access to private members.
40 friend class SymbolDefTree;
41
42public:
45 mlir::SymbolOpInterface getOp() const { return symbolDef; }
46
48 const SymbolDefTreeNode *getParent() const { return parent; }
49
51 bool hasChildren() const { return !children.empty(); }
52 size_t numChildren() const { return children.size(); }
53
55 using child_iterator = mlir::SetVector<SymbolDefTreeNode *>::const_iterator;
56 child_iterator begin() const { return children.begin(); }
57 child_iterator end() const { return children.end(); }
58
60 inline llvm::iterator_range<child_iterator> childIter() const {
61 return llvm::make_range(begin(), end());
62 }
63
65 std::string toString() const;
66 void print(llvm::raw_ostream &os) const;
67};
68
73 using NodeMapT = llvm::MapVector<mlir::SymbolOpInterface, std::unique_ptr<SymbolDefTreeNode>>;
74
76 NodeMapT nodes;
77
80
82 class NodeIterator final
83 : public llvm::mapped_iterator<
84 NodeMapT::const_iterator, SymbolDefTreeNode *(*)(const NodeMapT::value_type &)> {
85 static SymbolDefTreeNode *unwrap(const NodeMapT::value_type &value) {
86 return value.second.get();
87 }
88
89 public:
91 NodeIterator(NodeMapT::const_iterator it)
92 : llvm::mapped_iterator<
93 NodeMapT::const_iterator, SymbolDefTreeNode *(*)(const NodeMapT::value_type &)>(
94 it, &unwrap
95 ) {}
96 };
97
100 SymbolDefTreeNode *getOrAddNode(mlir::SymbolOpInterface symbolDef, SymbolDefTreeNode *parentNode);
101
102 void buildTree(mlir::SymbolOpInterface symbolOp, SymbolDefTreeNode *parentNode);
103
104public:
105 SymbolDefTree(mlir::SymbolOpInterface root);
106
108 const SymbolDefTreeNode *lookupNode(mlir::SymbolOpInterface symbolOp) const;
109
111 const SymbolDefTreeNode *getRoot() const { return &root; }
112
114 size_t size() const { return nodes.size(); }
115
117 using iterator = NodeIterator;
118 iterator begin() const { return nodes.begin(); }
119 iterator end() const { return nodes.end(); }
120
121 inline llvm::iterator_range<iterator> nodeIter() const {
122 return llvm::make_range(begin(), end());
123 }
124
126 inline void dump() const { print(llvm::errs()); }
127 void print(llvm::raw_ostream &os) const;
128
130 void dumpToDotFile(std::string filename = "") const;
131};
132
133} // namespace llzk
134
135namespace llvm {
136// Provide graph traits for traversing SymbolDefTree using standard graph traversals.
137
138template <> struct GraphTraits<const llzk::SymbolDefTreeNode *> {
140 static NodeRef getEntryNode(NodeRef node) { return node; }
141
144 static ChildIteratorType child_begin(NodeRef node) { return node->begin(); }
145 static ChildIteratorType child_end(NodeRef node) { return node->end(); }
146};
147
148template <>
149struct GraphTraits<const llzk::SymbolDefTree *>
150 : public GraphTraits<const llzk::SymbolDefTreeNode *> {
152
154 static NodeRef getEntryNode(GraphType g) { return g->getRoot(); }
155
158 static nodes_iterator nodes_begin(GraphType g) { return g->begin(); }
159 static nodes_iterator nodes_end(GraphType g) { return g->end(); }
160
162 static unsigned size(GraphType g) { return g->size(); }
163};
164
165// Provide graph traits for printing SymbolDefTree using dot graph printer.
166template <> struct DOTGraphTraits<const llzk::SymbolDefTreeNode *> : public DefaultDOTGraphTraits {
169
170 DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
171
172 std::string getNodeLabel(NodeRef n, GraphType) { return n->toString(); }
173};
174
175template <>
176struct DOTGraphTraits<const llzk::SymbolDefTree *>
177 : public DOTGraphTraits<const llzk::SymbolDefTreeNode *> {
178
179 DOTGraphTraits(bool isSimple = false) : DOTGraphTraits<NodeRef>(isSimple) {}
180
181 static std::string getGraphName(GraphType) { return "Symbol Def Tree"; }
182
183 std::string getNodeLabel(NodeRef n, GraphType g) {
185 }
186};
187
188} // namespace llvm
child_iterator begin() const
size_t numChildren() const
child_iterator end() const
mlir::SetVector< SymbolDefTreeNode * >::const_iterator child_iterator
Iterator over the children of this node.
bool hasChildren() const
Returns true if this node has any child edges.
std::string toString() const
Print the node in a human-readable format.
llvm::iterator_range< child_iterator > childIter() const
Range over child nodes.
const SymbolDefTreeNode * getParent() const
Returns the parent node in the tree. The root node will return nullptr.
mlir::SymbolOpInterface getOp() const
Returns the Symbol operation referenced by this node.
void print(llvm::raw_ostream &os) const
Builds a tree representing the symbol table structure.
NodeIterator iterator
An iterator over the nodes of the tree.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
llvm::iterator_range< iterator > nodeIter() const
iterator end() const
iterator begin() const
void dump() const
Dump the tree in a human-readable format.
SymbolDefTree(mlir::SymbolOpInterface root)
const SymbolDefTreeNode * getRoot() const
Returns the symbolic (i.e., no associated op) root node of the tree.
void dumpToDotFile(std::string filename="") const
Dump the tree to file in dot graph format.
void print(llvm::raw_ostream &os) const
size_t size() const
Return the total number of nodes in the tree.
std::string getNodeLabel(NodeRef n, GraphType g)
llzk::SymbolDefTreeNode::child_iterator ChildIteratorType
ChildIteratorType/child_begin/child_end - Allow iteration over the children of a node.
static ChildIteratorType child_begin(NodeRef node)
static ChildIteratorType child_end(NodeRef node)
static nodes_iterator nodes_begin(GraphType g)
static NodeRef getEntryNode(GraphType g)
The entry node into the graph is the symbolic root node of the tree.
static unsigned size(GraphType g)
Return the total number of nodes in the graph.
static nodes_iterator nodes_end(GraphType g)
llzk::SymbolDefTree::iterator nodes_iterator
nodes_iterator/begin/end - Allow iteration over all nodes in the graph.