LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolUseGraph.cpp
Go to the documentation of this file.
1//===-- SymbolUseGraph.cpp --------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
#include "llzk/Util/Compare.h"
#include "llzk/Util/Constants.h"

#include <mlir/IR/BuiltinOps.h>

#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/SmallSet.h>
#include <llvm/Support/GraphWriter.h>

#include <set>
24
25using namespace mlir;
26
27namespace llzk {
28
29//===----------------------------------------------------------------------===//
30// SymbolUseGraphNode
31//===----------------------------------------------------------------------===//
32
33void SymbolUseGraphNode::addSuccessor(SymbolUseGraphNode *node) {
34 if (this->successors.insert(node)) {
35 node->predecessors.insert(this);
36 }
37}
38
39void SymbolUseGraphNode::removeSuccessor(SymbolUseGraphNode *node) {
40 if (this->successors.remove(node)) {
41 node->predecessors.remove(this);
42 }
43}
44
45FailureOr<SymbolLookupResultUntyped>
46SymbolUseGraphNode::lookupSymbol(SymbolTableCollection &tables, bool reportMissing) const {
47 if (!isRealNode()) {
48 return failure();
49 }
50 Operation *lookupFrom = getSymbolPathRoot().getOperation();
51 auto res = lookupSymbolIn(tables, getSymbolPath(), lookupFrom, lookupFrom, reportMissing);
52 if (succeeded(res) || !reportMissing) {
53 return res;
54 }
55 // This is likely an error in the use graph and not a case that should ever happen.
56 return lookupFrom->emitError().append(
57 "Could not find symbol referenced in UseGraph: ", getSymbolPath()
58 );
59}
60
61//===----------------------------------------------------------------------===//
62// SymbolUseGraph
63//===----------------------------------------------------------------------===//
64
65namespace {
66
67template <typename R>
68R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
69 assert(defOp); // pre-condition
70
71 ModuleOp foundRoot;
72 FailureOr<SymbolRefAttr> path = llzk::getPathFromRoot(defOp, &foundRoot);
73 if (failed(path)) {
74 // This occurs if there is no root module with LANG_ATTR_NAME attribute
75 // or there is an unnamed module between the root module and the symbol.
76 auto diag = defOp.emitError("in SymbolUseGraph, failed to build symbol path");
77 diag.attachNote(defOp.getLoc()).append("for this SymbolOp");
78 diag.report();
79 return nullptr;
80 }
81 return callback(foundRoot, path.value());
82}
83
84} // namespace
85
86SymbolUseGraph::SymbolUseGraph(SymbolOpInterface rootSymbolOp) {
87 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
88 buildGraph(rootSymbolOp);
89}
90
92SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(const SymbolTable::SymbolUse &u) {
93 SymbolOpInterface userSymbol = getSelfOrParentOfType<SymbolOpInterface>(u.getUser());
94 return getPathAndCall<SymbolUseGraphNode *>(
95 userSymbol, [this, &userSymbol](ModuleOp r, SymbolRefAttr p) {
96 auto n = this->getOrAddNode(r, p, nullptr);
97 n->opsThatUseTheSymbol.insert(userSymbol);
98 return n;
99 }
100 );
101}
102
103void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
104 auto walkFn = [this](Operation *op, bool) {
105 assert(op->hasTrait<OpTrait::SymbolTable>());
106 FailureOr<ModuleOp> opRootModule = llzk::getRootModule(op);
107 if (failed(opRootModule)) {
108 return;
109 }
110
111 SymbolTableCollection tables;
112 if (auto usesOpt = llzk::getSymbolUses(&op->getRegion(0))) {
113 // Create child node for each Symbol use, as successor of the user Symbol op.
114 for (SymbolTable::SymbolUse u : usesOpt.value()) {
115 bool isStructParam = false;
116 Operation *user = u.getUser();
117 SymbolRefAttr symRef = u.getSymbolRef();
118 // Pending [LLZK-272] only a heuristic approach is possible. Check for FlatSymbolRefAttr
119 // where the user is a FieldRefOpInterface or the user is located within a StructDefOp and
120 // append the StructDefOp path with the FlatSymbolRefAttr.
121 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
122 if (auto fref = llvm::dyn_cast<component::FieldRefOpInterface>(user);
123 fref && fref.getFieldNameAttr() == flatSymRef) {
124 symRef = llzk::appendLeaf(fref.getStructType().getNameRef(), flatSymRef);
125 } else if (auto userStruct = getSelfOrParentOfType<component::StructDefOp>(user)) {
126 StringAttr localName = flatSymRef.getAttr();
127 isStructParam = userStruct.hasParamNamed(localName);
128 if (isStructParam || tables.getSymbolTable(userStruct).lookup(localName)) {
129 // If 'flatSymRef' is defined in the SymbolTable for 'userStruct' then it's
130 // a local symbol so prepend the full path of the struct itself.
131 auto parentPath = llzk::getPathFromRoot(userStruct);
132 assert(succeeded(parentPath));
133 symRef = llzk::appendLeaf(parentPath.value(), flatSymRef);
134 }
135 }
136 }
137 auto node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
138 node->isStructConstParam = isStructParam;
139 node->opsThatUseTheSymbol.insert(user);
140 }
141 }
142 };
143 SymbolTable::walkSymbolTables(symbolOp.getOperation(), true, walkFn);
144
145 // Find all nodes with no successors and add the tail node as successor.
146 for (SymbolUseGraphNode *n : nodesIter()) {
147 if (!n->hasSuccessor()) {
148 n->addSuccessor(&tail);
149 }
150 }
151}
152
153SymbolUseGraphNode *SymbolUseGraph::getOrAddNode(
154 ModuleOp pathRoot, SymbolRefAttr path, SymbolUseGraphNode *predecessorNode
155) {
156 NodeMapKeyT key = std::make_pair(pathRoot, path);
157 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
158 if (!nodeRef) {
159 nodeRef.reset(new SymbolUseGraphNode(pathRoot, path));
160 // When creating a new node, ensure it's attached to the graph, either as successor
161 // to the predecessor node (if given) else as successor to the root node.
162 if (predecessorNode) {
163 predecessorNode->addSuccessor(nodeRef.get());
164 } else {
165 root.addSuccessor(nodeRef.get());
166 }
167 } else if (predecessorNode) {
168 // When the node already exists and an additional predecessor node is given, add the node as a
169 // successor to the given predecessor node and detach from the 'root' (unless it's a self edge).
170 SymbolUseGraphNode *node = nodeRef.get();
171 predecessorNode->addSuccessor(node);
172 if (node != predecessorNode) {
173 root.removeSuccessor(node);
174 }
175 }
176 return nodeRef.get();
177}
178
179const SymbolUseGraphNode *SymbolUseGraph::lookupNode(ModuleOp pathRoot, SymbolRefAttr path) const {
180 NodeMapKeyT key = std::make_pair(pathRoot, path);
181 const auto *it = nodes.find(key);
182 return it == nodes.end() ? nullptr : it->second.get();
183}
184
185const SymbolUseGraphNode *SymbolUseGraph::lookupNode(SymbolOpInterface symbolDef) const {
186 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [this](ModuleOp r, SymbolRefAttr p) {
187 return this->lookupNode(r, p);
188 });
189}
190
191//===----------------------------------------------------------------------===//
192// Printing
193//===----------------------------------------------------------------------===//
194
195std::string SymbolUseGraphNode::toString(bool showLocations) const {
196 return buildStringViaPrint(*this, showLocations);
197}
198
199namespace {
200
201inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
202 if (root) {
203 FailureOr<SymbolRefAttr> unambiguousRoot = getPathFromTopRoot(root);
204 if (succeeded(unambiguousRoot)) {
205 os << unambiguousRoot.value() << '\n';
206 } else {
207 os << "<<unknown path>>\n";
208 }
209 } else {
210 os << "<<NULL MODULE>>\n";
211 }
212}
213
214} // namespace
215
217 llvm::raw_ostream &os, bool showLocations, std::string locationLinePrefix
218) const {
219 os << '\'' << symbolPath << '\'';
220 if (isStructConstParam) {
221 os << " (struct param)";
222 }
223 os << " with root module ";
224 safeAppendPathRoot(os, symbolPathRoot);
225 if (showLocations) {
226 // Print the user op locations (sorted for stable output). Printing only the location rather
227 // than the full Operation gives a short (single-line) format that's still useful for human
228 // debugging.
229 llvm::SmallSet<mlir::Location, 3, LocationComparator> locations;
230 for (Operation *user : getUserOps()) {
231 locations.insert(user->getLoc());
232 }
233 for (Location loc : locations) {
234 os << locationLinePrefix << loc << '\n';
235 }
236 }
237}
238
239void SymbolUseGraph::print(llvm::raw_ostream &os) const {
240 const SymbolUseGraphNode *rootPtr = &this->root;
241
242 // Tracks nodes that have been printed to ensure they are only printed once.
243 SmallPtrSet<SymbolUseGraphNode *, 16> done;
244
245 std::function<void(SymbolUseGraphNode *)> printNode = [rootPtr, &printNode, &done,
246 &os](SymbolUseGraphNode *node) {
247 // Skip if the node has been printed before
248 if (!done.insert(node).second) {
249 return;
250 }
251 // Print the current node
252 os << "// - Node : [" << node << "] ";
253 node->print(os, true, "// --- ");
254 // Print list of IDs for the predecessors (excluding root) and successors
255 os << "// --- Predecessors : [";
256 llvm::interleaveComma(
257 llvm::make_filter_range(
258 node->predecessorIter(), [rootPtr](SymbolUseGraphNode *n) { return n != rootPtr; }
259 ),
260 os
261 );
262 os << "]\n";
263 os << "// --- Successors : [";
264 llvm::interleaveComma(node->successorIter(), os);
265 os << "]\n";
266 // Recursively print the successors
267 for (SymbolUseGraphNode *c : node->successorIter()) {
268 printNode(c);
269 }
270 };
271
272 os << "// ---- SymbolUseGraph ----\n";
273 for (SymbolUseGraphNode *r : rootPtr->successorIter()) {
274 printNode(r);
275 }
276 os << "// ------------------------\n";
277 assert(done.size() == this->size() && "All nodes were not printed!");
278}
279
/// Dump the graph to `filename` in dot graph format via llvm::WriteGraph, naming the graph
/// "SymbolUseGraph".
/// NOTE(review): an empty `filename` presumably makes WriteGraph choose a file name itself;
/// confirm against llvm/Support/GraphWriter.h.
/// NOTE(review): `title` is not declared in this listing. The line that defines it (original
/// line 281) appears to have been dropped from this rendering; confirm against the source.
280void SymbolUseGraph::dumpToDotFile(std::string filename) const {
282 llvm::WriteGraph(this, "SymbolUseGraph", /*ShortNames*/ false, title, filename);
283}
284
285} // namespace llzk
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
void print(llvm::raw_ostream &os, bool showLocations=false, std::string locationLinePrefix="") const
std::string toString(bool showLocations=false) const
Return a string describing the node in a human-readable format (the same text print() writes).
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
const OpSet & getUserOps() const
The set of operations that use the symbol.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
Definition OpHelpers.h:32
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)