LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
CallGraph.cpp
Go to the documentation of this file.
1//===-- CallGraph.cpp - LLZK-specific call graph implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// The contents of this file are adapted from llvm/lib/Analysis/CallGraph.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
19
20#include <mlir/Analysis/CallGraph.h>
21#include <mlir/IR/Operation.h>
22#include <mlir/IR/SymbolTable.h>
23#include <mlir/Interfaces/CallInterfaces.h>
24
25#include <llvm/ADT/DepthFirstIterator.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/ErrorHandling.h>
28
29namespace llzk {
30
31using namespace function;
32
33//===----------------------------------------------------------------------===//
34// CallGraphNode
35//===----------------------------------------------------------------------===//
36
38bool CallGraphNode::isExternal() const { return !callableRegion; }
39
42mlir::Region *CallGraphNode::getCallableRegion() const {
43 assert(!isExternal() && "the external node has no callable region");
44 return callableRegion;
45}
46
48 return mlir::dyn_cast<FuncDefOp>(getCallableRegion()->getParentOp());
49}
50
53void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
54 assert(isExternal() && "abstract edges are only valid on external nodes");
55 addEdge(node, Edge::Kind::Abstract);
56}
57
59void CallGraphNode::addCallEdge(CallGraphNode *node) { addEdge(node, Edge::Kind::Call); }
60
62void CallGraphNode::addChildEdge(CallGraphNode *child) { addEdge(child, Edge::Kind::Child); }
63
66 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
67}
68
70void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
71 edges.insert({this, node, kind});
72}
73
74//===----------------------------------------------------------------------===//
75// CallGraph
76//===----------------------------------------------------------------------===//
77
80static void computeCallGraph(
81 mlir::Operation *op, CallGraph &cg, mlir::SymbolTableCollection &symbolTable,
82 CallGraphNode *parentNode, bool resolveCalls
83) {
84 if (mlir::CallOpInterface call = mlir::dyn_cast<mlir::CallOpInterface>(op)) {
85 // If there is no parent node, we ignore this operation. Even if this
86 // operation was a call, there would be no callgraph node to attribute it
87 // to.
88 if (resolveCalls && parentNode) {
89 parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
90 }
91 return;
92 }
93
94 // Compute the callgraph nodes and edges for each of the nested operations.
95 if (mlir::CallableOpInterface callable = mlir::dyn_cast<mlir::CallableOpInterface>(op)) {
96 if (auto *callableRegion = callable.getCallableRegion()) {
97 parentNode = cg.getOrAddNode(callableRegion, parentNode);
98 } else {
99 return;
100 }
101 }
102
103 for (mlir::Region &region : op->getRegions()) {
104 for (mlir::Operation &nested : region.getOps()) {
105 computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
106 }
107 }
108}
109
110CallGraph::CallGraph(mlir::Operation *op)
111 : externalCallerNode(/*callableRegion=*/nullptr),
112 unknownCalleeNode(/*callableRegion=*/nullptr) {
113 // Make two passes over the graph, one to compute the callables and one to
114 // resolve the calls. We split these up as we may have nested callable objects
115 // that need to be reserved before the calls.
116 mlir::SymbolTableCollection symbolTable;
117 computeCallGraph(
118 op, *this, symbolTable, /*parentNode=*/nullptr,
119 /*resolveCalls=*/false
120 );
121 computeCallGraph(
122 op, *this, symbolTable, /*parentNode=*/nullptr,
123 /*resolveCalls=*/true
124 );
125}
126
128CallGraphNode *CallGraph::getOrAddNode(mlir::Region *region, CallGraphNode *parentNode) {
129 assert(
130 region && mlir::isa<mlir::CallableOpInterface>(region->getParentOp()) &&
131 "expected parent operation to be callable"
132 );
133 std::unique_ptr<CallGraphNode> &node = nodes[region];
134 if (!node) {
135 node.reset(new CallGraphNode(region));
136
137 // Add this node to the given parent node if necessary.
138 if (parentNode) {
139 parentNode->addChildEdge(node.get());
140 } else {
141 // Otherwise, connect all callable nodes to the external node, this allows
142 // for conservatively including all callable nodes within the graph.
143 // FIXME This isn't correct, this is only necessary for callable nodes
144 // that *could* be called from external sources. This requires extending
145 // the interface for callables to check if they may be referenced
146 // externally.
147 externalCallerNode.addAbstractEdge(node.get());
148 }
149 }
150 return node.get();
151}
152
155CallGraphNode *CallGraph::lookupNode(mlir::Region *region) const {
156 const auto *it = nodes.find(region);
157 return it == nodes.end() ? nullptr : it->second.get();
158}
159
164 mlir::CallOpInterface call, mlir::SymbolTableCollection &symbolTable
165) const {
166 auto res = llzk::resolveCallable<FuncDefOp>(symbolTable, call);
167 if (mlir::succeeded(res)) {
168 if (auto *node = lookupNode(res->get().getCallableRegion())) {
169 return node;
170 }
171 }
172
173 return getUnknownCalleeNode();
174}
175
178 // Erase any children of this node first.
179 if (node->hasChildren()) {
180 for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node)) {
181 if (edge.isChild()) {
182 eraseNode(edge.getTarget());
183 }
184 }
185 }
186 // Erase any edges to this node from any other nodes.
187 for (auto &it : nodes) {
188 it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
189 return edge.getTarget() == node;
190 });
191 }
192 nodes.erase(node->getCallableRegion());
193}
194
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
197
199void CallGraph::dump() const { print(llvm::errs()); }
200void CallGraph::print(llvm::raw_ostream &os) const {
201 os << "// ---- CallGraph ----\n";
202
203 // Functor used to output the name for the given node.
204 auto emitNodeName = [&](const CallGraphNode *node) {
205 if (node == getExternalCallerNode()) {
206 os << "<External-Caller-Node>";
207 return;
208 }
209 if (node == getUnknownCalleeNode()) {
210 os << "<Unknown-Callee-Node>";
211 return;
212 }
213
214 auto *callableRegion = node->getCallableRegion();
215 auto *parentOp = callableRegion->getParentOp();
216 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
217 << callableRegion->getRegionNumber();
218 auto attrs = parentOp->getAttrDictionary();
219 if (!isNullOrEmpty(attrs)) {
220 os << " : " << attrs;
221 }
222 };
223
224 for (auto &nodeIt : nodes) {
225 const CallGraphNode *node = nodeIt.second.get();
226
227 // Dump the header for this node.
228 os << "// - Node : ";
229 emitNodeName(node);
230 os << "\n";
231
232 // Emit each of the edges.
233 for (auto &edge : *node) {
234 os << "// -- ";
235 if (edge.isCall()) {
236 os << "Call";
237 } else if (edge.isChild()) {
238 os << "Child";
239 }
240
241 os << "-Edge : ";
242 emitNodeName(edge.getTarget());
243 os << "\n";
244 }
245 os << "//\n";
246 }
247
248 os << "// -- SCCs --\n";
249
250 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
251 os << "// - SCC : \n";
252 for (auto &node : scc) {
253 os << "// -- Node :";
254 emitNodeName(node);
255 os << "\n";
256 }
257 os << "\n";
258 }
259
260 os << "// -------------------\n";
261}
262
263} // namespace llzk
This class represents a directed edge between two nodes in the callgraph.
Definition CallGraph.h:42
bool isChild() const
Returns true if this edge represents a Child edge.
Definition CallGraph.h:69
CallGraphNode * getTarget() const
Returns the target node for this edge.
Definition CallGraph.h:76
This is a simple port of the mlir::CallGraphNode with llzk::CallGraph as a friend class,...
Definition CallGraph.h:39
bool isExternal() const
Returns true if this node is an external node.
Definition CallGraph.cpp:38
mlir::Region * getCallableRegion() const
Returns the callable region this node represents.
Definition CallGraph.cpp:42
void addChildEdge(CallGraphNode *child)
Adds a reference edge to the given child node.
Definition CallGraph.cpp:62
bool hasChildren() const
Returns true if this node has any child edges.
Definition CallGraph.cpp:65
void addCallEdge(CallGraphNode *node)
Add an outgoing call edge from this node.
Definition CallGraph.cpp:59
void addAbstractEdge(CallGraphNode *node)
Adds an abstract reference edge to the given node.
Definition CallGraph.cpp:53
iterator end() const
Definition CallGraph.h:125
llzk::function::FuncDefOp getCalledFunction() const
Returns the called function that the callable region represents.
Definition CallGraph.cpp:47
This is a port of mlir::CallGraph that has been adapted to use the custom symbol lookup helpers (see ...
Definition CallGraph.h:167
CallGraph(mlir::Operation *op)
CallGraphNode * getExternalCallerNode() const
Return the callgraph node representing an external caller.
Definition CallGraph.h:199
void dump() const
Dump the graph in a human readable format.
void print(llvm::raw_ostream &os) const
void eraseNode(CallGraphNode *node)
Erase the given node from the callgraph.
CallGraphNode * resolveCallable(mlir::CallOpInterface call, mlir::SymbolTableCollection &symbolTable) const
Resolve the callable for the given call to a node in the callgraph, or the unknown-callee node if a valid node was not resolved.
CallGraphNode * lookupNode(mlir::Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
CallGraphNode * getUnknownCalleeNode() const
Return the callgraph node representing an indirect callee.
Definition CallGraph.h:204
CallGraphNode * getOrAddNode(mlir::Region *region, CallGraphNode *parentNode)
Get or add a call graph node for the given region.
bool isNullOrEmpty(mlir::ArrayAttr a)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.