LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolUseGraph.cpp
Go to the documentation of this file.
1//===-- SymbolUseGraph.cpp --------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13#include "llzk/Util/Constants.h"
17
18#include <mlir/IR/BuiltinOps.h>
19
20#include <llvm/ADT/SmallPtrSet.h>
21#include <llvm/Support/GraphWriter.h>
22
23using namespace mlir;
24
25namespace llzk {
26
27//===----------------------------------------------------------------------===//
28// SymbolUseGraphNode
29//===----------------------------------------------------------------------===//
30
31void SymbolUseGraphNode::addSuccessor(SymbolUseGraphNode *node) {
32 if (this->successors.insert(node)) {
33 node->predecessors.insert(this);
34 }
35}
36
37void SymbolUseGraphNode::removeSuccessor(SymbolUseGraphNode *node) {
38 if (this->successors.remove(node)) {
39 node->predecessors.remove(this);
40 }
41}
42
43FailureOr<SymbolLookupResultUntyped>
44SymbolUseGraphNode::lookupSymbol(SymbolTableCollection &tables, bool reportMissing) const {
45 if (!isRealNode()) {
46 return failure();
47 }
48 Operation *lookupFrom = getSymbolPathRoot().getOperation();
49 auto res = lookupSymbolIn(tables, getSymbolPath(), lookupFrom, lookupFrom, reportMissing);
50 if (succeeded(res) || !reportMissing) {
51 return res;
52 }
53 // This is likely an error in the use graph and not a case that should ever happen.
54 return lookupFrom->emitError().append(
55 "Could not find symbol referenced in UseGraph: ", getSymbolPath()
56 );
57}
58
59//===----------------------------------------------------------------------===//
60// SymbolUseGraph
61//===----------------------------------------------------------------------===//
62
63namespace {
64
65template <typename R>
66R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
67 assert(defOp); // pre-condition
68
69 ModuleOp foundRoot;
70 FailureOr<SymbolRefAttr> path = llzk::getPathFromRoot(defOp, &foundRoot);
71 if (failed(path)) {
72 // This occurs if there is no root module with LANG_ATTR_NAME attribute
73 // or there is an unnamed module between the root module and the symbol.
74 auto diag = defOp.emitError("in SymbolUseGraph, failed to build symbol path");
75 diag.attachNote(defOp.getLoc()).append("for this SymbolOp");
76 diag.report();
77 return nullptr;
78 }
79 return callback(foundRoot, path.value());
80}
81
82} // namespace
83
84SymbolUseGraph::SymbolUseGraph(SymbolOpInterface rootSymbolOp) {
85 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
86 buildGraph(rootSymbolOp);
87}
88
90SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(const SymbolTable::SymbolUse &u) {
91 SymbolOpInterface userSymbol = getSelfOrParentOfType<SymbolOpInterface>(u.getUser());
92 return getPathAndCall<SymbolUseGraphNode *>(userSymbol, [this](ModuleOp r, SymbolRefAttr p) {
93 return this->getOrAddNode(r, p, nullptr);
94 });
95}
96
97void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
98 auto walkFn = [this](Operation *op, bool) {
99 assert(op->hasTrait<OpTrait::SymbolTable>());
100 FailureOr<ModuleOp> opRootModule = llzk::getRootModule(op);
101 if (failed(opRootModule)) {
102 return;
103 }
104
105 SymbolTableCollection tables;
106 if (auto usesOpt = llzk::getSymbolUses(&op->getRegion(0))) {
107 // Create child node for each Symbol use, as successor of the user Symbol op.
108 for (SymbolTable::SymbolUse u : usesOpt.value()) {
109 bool isStructParam = false;
110 SymbolRefAttr symRef = u.getSymbolRef();
111 // Pending [LLZK-272] only a heuristic approach is possible. Check for FlatSymbolRefAttr
112 // where the user is a FieldRefOpInterface or the user is located within a StructDefOp and
113 // append the StructDefOp path with the FlatSymbolRefAttr.
114 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
115 Operation *user = u.getUser();
116 if (auto fref = llvm::dyn_cast<component::FieldRefOpInterface>(user);
117 fref && fref.getFieldNameAttr() == flatSymRef) {
118 symRef = llzk::appendLeaf(fref.getStructType().getNameRef(), flatSymRef);
119 } else if (auto userStruct = getSelfOrParentOfType<component::StructDefOp>(user)) {
120 StringAttr localName = flatSymRef.getAttr();
121 isStructParam = userStruct.hasParamNamed(localName);
122 if (isStructParam || tables.getSymbolTable(userStruct).lookup(localName)) {
123 // If 'flatSymRef' is defined in the SymbolTable for 'userStruct' then it's
124 // a local symbol so prepend the full path of the struct itself.
125 auto parentPath = llzk::getPathFromRoot(userStruct);
126 assert(succeeded(parentPath));
127 symRef = llzk::appendLeaf(parentPath.value(), flatSymRef);
128 }
129 }
130 }
131 auto node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
132 node->isStructConstParam = isStructParam;
133 }
134 }
135 };
136 SymbolTable::walkSymbolTables(symbolOp.getOperation(), true, walkFn);
137
138 // Find all nodes with no successors and add the tail node as successor.
139 for (SymbolUseGraphNode *n : nodesIter()) {
140 if (!n->hasSuccessor()) {
141 n->addSuccessor(&tail);
142 }
143 }
144}
145
146SymbolUseGraphNode *SymbolUseGraph::getOrAddNode(
147 ModuleOp pathRoot, SymbolRefAttr path, SymbolUseGraphNode *predecessorNode
148) {
149 NodeMapKeyT key = std::make_pair(pathRoot, path);
150 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
151 if (!nodeRef) {
152 nodeRef.reset(new SymbolUseGraphNode(pathRoot, path));
153 // When creating a new node, ensure it's attached to the graph, either as successor
154 // to the predecessor node (if given) else as successor to the root node.
155 if (predecessorNode) {
156 predecessorNode->addSuccessor(nodeRef.get());
157 } else {
158 root.addSuccessor(nodeRef.get());
159 }
160 } else if (predecessorNode) {
161 // When the node already exists and an additional predecessor node is given, add the node as a
162 // successor to the given predecessor node and detach from the 'root' (unless it's a self edge).
163 SymbolUseGraphNode *node = nodeRef.get();
164 predecessorNode->addSuccessor(node);
165 if (node != predecessorNode) {
166 root.removeSuccessor(node);
167 }
168 }
169 return nodeRef.get();
170}
171
172const SymbolUseGraphNode *SymbolUseGraph::lookupNode(ModuleOp pathRoot, SymbolRefAttr path) const {
173 NodeMapKeyT key = std::make_pair(pathRoot, path);
174 const auto *it = nodes.find(key);
175 return it == nodes.end() ? nullptr : it->second.get();
176}
177
178const SymbolUseGraphNode *SymbolUseGraph::lookupNode(SymbolOpInterface symbolDef) const {
179 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [this](ModuleOp r, SymbolRefAttr p) {
180 return this->lookupNode(r, p);
181 });
182}
183
184//===----------------------------------------------------------------------===//
185// Printing
186//===----------------------------------------------------------------------===//
187
188std::string SymbolUseGraphNode::toString() const { return buildStringViaPrint(*this); }
189
190namespace {
191
192inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
193 if (root) {
194 FailureOr<SymbolRefAttr> unambiguousRoot = getPathFromTopRoot(root);
195 if (succeeded(unambiguousRoot)) {
196 os << unambiguousRoot.value() << '\n';
197 } else {
198 os << "<<unknown path>>\n";
199 }
200 } else {
201 os << "<<NULL MODULE>>\n";
202 }
203}
204
205} // namespace
206
207void SymbolUseGraphNode::print(llvm::raw_ostream &os) const {
208 os << '\'' << symbolPath << '\'';
209 if (isStructConstParam) {
210 os << " (struct param)";
211 }
212 os << " with root module ";
213 safeAppendPathRoot(os, symbolPathRoot);
214}
215
216void SymbolUseGraph::print(llvm::raw_ostream &os) const {
217 const SymbolUseGraphNode *rootPtr = &this->root;
218
219 // Tracks nodes that have been printed to ensure they are only printed once.
220 SmallPtrSet<SymbolUseGraphNode *, 16> done;
221
222 std::function<void(SymbolUseGraphNode *)> printNode = [rootPtr, &printNode, &done,
223 &os](SymbolUseGraphNode *node) {
224 // Skip if the node has been printed before
225 if (!done.insert(node).second) {
226 return;
227 }
228 // Print the current node
229 os << "// - Node : [" << node << "] ";
230 node->print(os);
231 // Print list of IDs for the predecessors (excluding root) and successors
232 os << "// --- Predecessors : [";
233 llvm::interleaveComma(
234 llvm::make_filter_range(
235 node->predecessorIter(), [rootPtr](SymbolUseGraphNode *n) { return n != rootPtr; }
236 ),
237 os
238 );
239 os << "]\n";
240 os << "// --- Successors : [";
241 llvm::interleaveComma(node->successorIter(), os);
242 os << "]\n";
243 // Recursively print the successors
244 for (SymbolUseGraphNode *c : node->successorIter()) {
245 printNode(c);
246 }
247 };
248
249 os << "// ---- SymbolUseGraph ----\n";
250 for (SymbolUseGraphNode *r : rootPtr->successorIter()) {
251 printNode(r);
252 }
253 os << "// ------------------------\n";
254 assert(done.size() == this->size() && "All nodes were not printed!");
255}
256
/// Dump the graph to a file in Graphviz dot format by delegating to `llvm::WriteGraph`.
///
/// @param filename destination file path. The declaration defaults it to "". Presumably,
///        llvm::WriteGraph then picks a temporary file name derived from "SymbolUseGraph";
///        verify against llvm/Support/GraphWriter.h.
void SymbolUseGraph::dumpToDotFile(std::string filename) const {
  // NOTE(review): `title` is not declared within the lines visible here; confirm its definition.
  llvm::WriteGraph(this, "SymbolUseGraph", /*ShortNames*/ false, title, filename);
}
261
262} // namespace llzk
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
std::string toString() const
Print the node in a human readable format.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
void print(llvm::raw_ostream &os) const
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
std::string buildStringViaPrint(const T &base)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
Definition OpHelpers.h:32
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)