LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolUseGraph.h
Go to the documentation of this file.
1//===-- SymbolUseGraph.h ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
13
14#include <mlir/IR/BuiltinOps.h>
15#include <mlir/IR/SymbolTable.h>
16
17#include <llvm/ADT/GraphTraits.h>
18#include <llvm/ADT/MapVector.h>
19#include <llvm/ADT/SetVector.h>
20#include <llvm/ADT/SmallPtrSet.h>
21#include <llvm/Support/DOTGraphTraits.h>
22
23#include <utility>
24
25namespace llzk {
26
27class SymbolUseGraphNode {
28 using OpSet = llvm::SmallPtrSet<mlir::Operation *, 3>;
29
30 mlir::ModuleOp symbolPathRoot;
31 mlir::SymbolRefAttr symbolPath;
32 OpSet opsThatUseTheSymbol;
33 bool isStructConstParam;
34
35 /* Tree structure. The SymbolUseGraph owns the nodes so just pointers here. */
37 mlir::SetVector<SymbolUseGraphNode *> predecessors;
39 mlir::SetVector<SymbolUseGraphNode *> successors;
40
41 SymbolUseGraphNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path)
42 : symbolPathRoot(pathRoot), symbolPath(path), isStructConstParam(false) {
43 assert(pathRoot && "'pathRoot' cannot be nullptr");
44 assert(path && "'path' cannot be nullptr");
45 }
46
48 SymbolUseGraphNode() : symbolPathRoot(nullptr), symbolPath(nullptr), isStructConstParam(false) {}
49
51 static bool isRealNodeImpl(const SymbolUseGraphNode *node) { return node->symbolPath != nullptr; }
52
54 void addSuccessor(SymbolUseGraphNode *node);
55
57 void removeSuccessor(SymbolUseGraphNode *node);
58
59 // Provide access to private members.
60 friend class SymbolUseGraph;
61
62public:
65 bool isRealNode() const { return isRealNodeImpl(this); }
66
68 mlir::ModuleOp getSymbolPathRoot() const { return symbolPathRoot; }
69
71 mlir::SymbolRefAttr getSymbolPath() const { return symbolPath; }
72
74 const OpSet &getUserOps() const { return opsThatUseTheSymbol; }
75
77 bool isStructParam() const { return isStructConstParam; }
78
80 bool hasPredecessor() const {
81 return llvm::find_if(predecessors, isRealNodeImpl) != predecessors.end();
82 }
83 size_t numPredecessors() const { return llvm::count_if(predecessors, isRealNodeImpl); }
84
86 bool hasSuccessor() const {
87 return llvm::find_if(successors, isRealNodeImpl) != successors.end();
88 }
89 size_t numSuccessors() const { return llvm::count_if(successors, isRealNodeImpl); }
90
92 using iterator = llvm::filter_iterator<
93 mlir::SetVector<SymbolUseGraphNode *>::const_iterator, bool (*)(const SymbolUseGraphNode *)>;
94
95 inline iterator predecessors_begin() const { return predecessorIter().begin(); }
96 inline iterator predecessors_end() const { return predecessorIter().end(); }
97 inline iterator successors_begin() const { return successorIter().begin(); }
98 inline iterator successors_end() const { return successorIter().end(); }
99
101 llvm::iterator_range<iterator> predecessorIter() const {
102 return llvm::make_filter_range(predecessors, isRealNodeImpl);
103 }
104
106 llvm::iterator_range<iterator> successorIter() const {
107 return llvm::make_filter_range(successors, isRealNodeImpl);
108 }
109
110 mlir::FailureOr<SymbolLookupResultUntyped>
111 lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing = true) const;
112
114 std::string toString(bool showLocations = false) const;
115 void print(llvm::raw_ostream &os, bool showLocations = false, std::string locationLinePrefix = "")
116 const;
117};
118
123 using NodeMapKeyT = std::pair<mlir::ModuleOp, mlir::SymbolRefAttr>;
125 using NodeMapT = llvm::MapVector<NodeMapKeyT, std::unique_ptr<SymbolUseGraphNode>>;
126
128 NodeMapT nodes;
129
130 // The singleton artificial (i.e., no associated op) root/head and tail nodes of the graph. Every
131 // newly created SymbolUseGraphNode is initially a successor of the root node until a real
132 // successor (if any) is added. Similarly, all leaf nodes in the graph have the tail as successor.
133 //
134 // Implementation note: An actual SymbolUseGraphNode is used instead of lists of head/tail nodes
135 // because the GraphTraits implementations require a single entry node. These nodes are not added
136 // to the `nodes` set and should be transparent to users of this graph (other than through the
137 // GraphTraits `getEntryNode()` function implementations).
138 SymbolUseGraphNode root, tail;
139
141 class NodeIterator final
142 : public llvm::mapped_iterator<
143 NodeMapT::const_iterator, SymbolUseGraphNode *(*)(const NodeMapT::value_type &)> {
144 static SymbolUseGraphNode *unwrap(const NodeMapT::value_type &value) {
145 return value.second.get();
146 }
147
148 public:
150 NodeIterator(NodeMapT::const_iterator it)
151 : llvm::mapped_iterator<
152 NodeMapT::const_iterator, SymbolUseGraphNode *(*)(const NodeMapT::value_type &)>(
153 it, &unwrap
154 ) {}
155 };
156
158 SymbolUseGraphNode *getOrAddNode(
159 mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path, SymbolUseGraphNode *predecessorNode
160 );
161
162 SymbolUseGraphNode *getSymbolUserNode(const mlir::SymbolTable::SymbolUse &u);
163 void buildGraph(mlir::SymbolOpInterface symbolOp);
164
165 // Friend declarations for the specializations of GraphTraits
166 friend struct llvm::GraphTraits<const llzk::SymbolUseGraph *>;
167 friend struct llvm::GraphTraits<llvm::Inverse<const llzk::SymbolUseGraph *>>;
168
169public:
170 SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp);
171
173 const SymbolUseGraphNode *lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const;
174
176 const SymbolUseGraphNode *lookupNode(mlir::SymbolOpInterface symbolDef) const;
177
179 size_t size() const { return nodes.size(); }
180
183 roots_iterator roots_begin() const { return root.successors_begin(); }
184 roots_iterator roots_end() const { return root.successors_end(); }
185
187 inline llvm::iterator_range<roots_iterator> rootsIter() const {
188 return llvm::make_range(roots_begin(), roots_end());
189 }
190
192 using iterator = NodeIterator;
193 iterator begin() const { return nodes.begin(); }
194 iterator end() const { return nodes.end(); }
195
197 inline llvm::iterator_range<iterator> nodesIter() const {
198 return llvm::make_range(begin(), end());
199 }
200
202 inline void dump() const { print(llvm::errs()); }
203 void print(llvm::raw_ostream &os) const;
204
206 void dumpToDotFile(std::string filename = "") const;
207};
208
209} // namespace llzk
210
211namespace llvm {
212
213// Provide graph traits for traversing SymbolUseGraph using standard graph traversals.
214template <> struct GraphTraits<const llzk::SymbolUseGraphNode *> {
216 static NodeRef getEntryNode(NodeRef node) { return node; }
217
220 static ChildIteratorType child_begin(NodeRef node) { return node->successors_begin(); }
221 static ChildIteratorType child_end(NodeRef node) { return node->successors_end(); }
222};
223
224template <>
225struct GraphTraits<const llzk::SymbolUseGraph *>
226 : public GraphTraits<const llzk::SymbolUseGraphNode *> {
228
230 static NodeRef getEntryNode(GraphType g) { return &g->root; }
231
234 static nodes_iterator nodes_begin(GraphType g) { return g->begin(); }
235 static nodes_iterator nodes_end(GraphType g) { return g->end(); }
236
238 static unsigned size(GraphType g) { return g->size(); }
239};
240
241// Provide graph traits for traversing SymbolUseGraph using INVERSE graph traversals.
242template <> struct GraphTraits<Inverse<const llzk::SymbolUseGraphNode *>> {
244 static NodeRef getEntryNode(Inverse<NodeRef> node) { return node.Graph; }
245
248 static ChildIteratorType child_end(NodeRef node) { return node->predecessors_end(); }
249};
250
251template <>
252struct GraphTraits<Inverse<const llzk::SymbolUseGraph *>>
253 : public GraphTraits<Inverse<const llzk::SymbolUseGraphNode *>> {
254 using GraphType = Inverse<const llzk::SymbolUseGraph *>;
255
257 static NodeRef getEntryNode(GraphType g) { return &g.Graph->tail; }
258
261 static nodes_iterator nodes_begin(GraphType g) { return g.Graph->begin(); }
262 static nodes_iterator nodes_end(GraphType g) { return g.Graph->end(); }
263
265 static unsigned size(GraphType g) { return g.Graph->size(); }
266};
267
268// Provide graph traits for printing SymbolUseGraph using dot graph printer.
269template <> struct DOTGraphTraits<const llzk::SymbolUseGraphNode *> : public DefaultDOTGraphTraits {
272
273 DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
274
275 std::string getNodeLabel(NodeRef n, GraphType) { return n->toString(true); }
276};
277
278template <>
279struct DOTGraphTraits<const llzk::SymbolUseGraph *>
280 : public DOTGraphTraits<const llzk::SymbolUseGraphNode *> {
281
282 DOTGraphTraits(bool isSimple = false) : DOTGraphTraits<NodeRef>(isSimple) {}
283
284 static std::string getGraphName(GraphType) { return "Symbol Use Graph"; }
285
286 std::string getNodeLabel(NodeRef n, GraphType g) {
288 }
289};
290
291} // namespace llvm
This file defines methods for symbol lookup across LLZK operations and included files.
size_t numSuccessors() const
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
bool hasSuccessor() const
Return true if this node has any successors.
iterator predecessors_begin() const
size_t numPredecessors() const
void print(llvm::raw_ostream &os, bool showLocations=false, std::string locationLinePrefix="") const
iterator predecessors_end() const
bool isStructParam() const
Return true iff the symbol is a struct constant parameter name.
bool hasPredecessor() const
Return true if this node has any predecessors.
std::string toString(bool showLocations=false) const
Print the node in a human-readable format.
iterator successors_begin() const
llvm::iterator_range< iterator > predecessorIter() const
Range over predecessor nodes.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
const OpSet & getUserOps() const
The set of operations that use the symbol.
iterator successors_end() const
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
llvm::filter_iterator< mlir::SetVector< SymbolUseGraphNode * >::const_iterator, bool(*)(const SymbolUseGraphNode *)> iterator
Iterator over predecessors/successors.
Builds a graph structure representing the relationships between symbols and their uses.
size_t size() const
Return the total number of nodes in the graph.
void dump() const
Dump the graph in a human-readable format.
roots_iterator roots_end() const
SymbolUseGraphNode::iterator roots_iterator
Iterator over the root nodes (i.e., nodes that have no predecessors).
llvm::iterator_range< roots_iterator > rootsIter() const
Range over root nodes (i.e., nodes that have no predecessors).
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
iterator begin() const
iterator end() const
roots_iterator roots_begin() const
NodeIterator iterator
An iterator over the nodes of the graph.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::string getNodeLabel(NodeRef n, GraphType g)
llzk::SymbolUseGraph::iterator nodes_iterator
nodes_iterator/begin/end - Allow iteration over all nodes in the graph.
static NodeRef getEntryNode(GraphType g)
The entry node into the inverse graph is the tail node.
static unsigned size(GraphType g)
Return total number of nodes in the graph.
static ChildIteratorType child_end(NodeRef node)
llzk::SymbolUseGraphNode::iterator ChildIteratorType
ChildIteratorType/child_begin/child_end - Allow iteration over the children of a node.
static ChildIteratorType child_begin(NodeRef node)
static nodes_iterator nodes_begin(GraphType g)
llzk::SymbolUseGraph::iterator nodes_iterator
nodes_iterator/begin/end - Allow iteration over all nodes in the graph.
static unsigned size(GraphType g)
Return total number of nodes in the graph.
static NodeRef getEntryNode(GraphType g)
The entry node into the graph is the root node.
static nodes_iterator nodes_end(GraphType g)