LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolTableLLZK.cpp
Go to the documentation of this file.
1//===-- SymbolTableLLZK.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's mlir/lib/IR/SymbolTable.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
34//===----------------------------------------------------------------------===//
35
37
38#include <llvm/ADT/SmallPtrSet.h>
39
40using namespace mlir;
41
42//===----------------------------------------------------------------------===//
43// Symbol Use Lists
44//===----------------------------------------------------------------------===//
45
46namespace {
47
50static bool isPotentiallyUnknownSymbolTable(Operation *op) {
51 return op->getNumRegions() == 1 && !op->getDialect();
52}
53
55static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
56 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
57}
58
63static LogicalResult collectValidReferencesFor(
64 Operation *symbol, StringAttr symbolName, Operation *within,
65 SmallVectorImpl<SymbolRefAttr> &results
66) {
67 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
68 MLIRContext *ctx = symbol->getContext();
69
70 auto leafRef = FlatSymbolRefAttr::get(symbolName);
71 results.push_back(leafRef);
72
73 // Early exit for when 'within' is the parent of 'symbol'.
74 Operation *symbolTableOp = symbol->getParentOp();
75 if (within == symbolTableOp) {
76 return success();
77 }
78
79 // Collect references until 'symbolTableOp' reaches 'within'.
80 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
81 StringAttr symbolNameId = StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
82 do {
83 // Each parent of 'symbol' should define a symbol table.
84 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
85 return failure();
86 }
87 // Each parent of 'symbol' should also be a symbol.
88 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
89 if (!symbolTableName) {
90 return failure();
91 }
92 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
93
94 symbolTableOp = symbolTableOp->getParentOp();
95 if (symbolTableOp == within) {
96 break;
97 }
98 nestedRefs.insert(nestedRefs.begin(), FlatSymbolRefAttr::get(symbolTableName));
99 } while (true);
100 return success();
101}
102
106static std::optional<WalkResult> walkSymbolTable(
107 MutableArrayRef<Region> regions, function_ref<std::optional<WalkResult>(Operation *)> callback
108) {
109 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
110 while (!worklist.empty()) {
111 for (Operation &op : worklist.pop_back_val()->getOps()) {
112 std::optional<WalkResult> result = callback(&op);
113 if (result != WalkResult::advance()) {
114 return result;
115 }
116
117 // If this op defines a new symbol table scope, we can't traverse. Any
118 // symbol references nested within 'op' are different semantically.
119 if (!op.hasTrait<OpTrait::SymbolTable>()) {
120 for (Region &region : op.getRegions()) {
121 worklist.push_back(&region);
122 }
123 }
124 }
125 }
126 return WalkResult::advance();
127}
128
131static WalkResult
132walkSymbolRefs(Operation *op, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
133 // This is modified for LLZK.
134 auto walkFn = [&op, &callback](SymbolRefAttr symbolRef) {
135 if (callback({op, symbolRef}).wasInterrupted()) {
136 return WalkResult::interrupt();
137 }
138 return WalkResult::skip(); // Don't walk nested references.
139 };
140 for (Type t : op->getOperandTypes()) {
141 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
142 return WalkResult::interrupt();
143 }
144 }
145 for (Type t : op->getResultTypes()) {
146 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
147 return WalkResult::interrupt();
148 }
149 }
150 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(walkFn);
151}
152
156static std::optional<WalkResult> walkSymbolUses(
157 MutableArrayRef<Region> regions, function_ref<WalkResult(SymbolTable::SymbolUse)> callback
158) {
159 return walkSymbolTable(regions, [&](Operation *op) -> std::optional<WalkResult> {
160 // Check that this isn't a potentially unknown symbol table.
161 if (isPotentiallyUnknownSymbolTable(op)) {
162 return std::nullopt;
163 }
164 return walkSymbolRefs(op, callback);
165 });
166}
167
171static std::optional<WalkResult>
172walkSymbolUses(Operation *from, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
173 // If this operation has regions, and it, as well as its dialect, isn't
174 // registered then conservatively fail. The operation may define a
175 // symbol table, so we can't opaquely know if we should traverse to find
176 // nested uses.
177 if (isPotentiallyUnknownSymbolTable(from)) {
178 return std::nullopt;
179 }
180
181 // Walk the uses on this operation.
182 if (walkSymbolRefs(from, callback).wasInterrupted()) {
183 return WalkResult::interrupt();
184 }
185
186 // Only recurse if this operation is not a symbol table. A symbol table
187 // defines a new scope, so we can't walk the attributes from within the symbol
188 // table op.
189 if (!from->hasTrait<OpTrait::SymbolTable>()) {
190 return walkSymbolUses(from->getRegions(), callback);
191 }
192 return WalkResult::advance();
193}
194
200struct SymbolScope {
204 template <
205 typename CallbackT,
206 std::enable_if_t<!std::is_same<
207 typename llvm::function_traits<CallbackT>::result_t, void>::value> * = nullptr>
208 std::optional<WalkResult> walk(CallbackT cback) {
209 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
210 return walkSymbolUses(*region, cback);
211 }
212 return walkSymbolUses(limit.get<Operation *>(), cback);
213 }
216 template <
217 typename CallbackT,
218 std::enable_if_t<std::is_same<
219 typename llvm::function_traits<CallbackT>::result_t, void>::value> * = nullptr>
220 std::optional<WalkResult> walk(CallbackT cback) {
221 return walk([=](SymbolTable::SymbolUse use) { return cback(use), WalkResult::advance(); });
222 }
223
226 template <typename CallbackT> std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
227 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
228 return ::walkSymbolTable(*region, cback);
229 }
230 return ::walkSymbolTable(limit.get<Operation *>(), cback);
231 }
232
234 SymbolRefAttr symbol;
235
237 llvm::PointerUnion<Operation *, Region *> limit;
238};
239
241static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Operation *limit) {
242 StringAttr symName = SymbolTable::getSymbolName(symbol);
243 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
244
245 // Compute the ancestors of 'limit'.
246 SetVector<Operation *, SmallVector<Operation *, 4>, SmallPtrSet<Operation *, 4>> limitAncestors;
247 Operation *limitAncestor = limit;
248 do {
249 // Check to see if 'symbol' is an ancestor of 'limit'.
250 if (limitAncestor == symbol) {
251 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
252 // doesn't support parent references.
253 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) {
254 return {{SymbolRefAttr::get(symName), limit}};
255 }
256 return {};
257 }
258
259 limitAncestors.insert(limitAncestor);
260 } while ((limitAncestor = limitAncestor->getParentOp()));
261
262 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
263 Operation *commonAncestor = symbol->getParentOp();
264 do {
265 if (limitAncestors.count(commonAncestor)) {
266 break;
267 }
268 } while ((commonAncestor = commonAncestor->getParentOp()));
269 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
270
271 // Compute the set of valid nested references for 'symbol' as far up to the
272 // common ancestor as possible.
273 SmallVector<SymbolRefAttr, 2> references;
274 bool collectedAllReferences =
275 succeeded(collectValidReferencesFor(symbol, symName, commonAncestor, references));
276
277 // Handle the case where the common ancestor is 'limit'.
278 if (commonAncestor == limit) {
279 SmallVector<SymbolScope, 2> scopes;
280
281 // Walk each of the ancestors of 'symbol', calling the compute function for
282 // each one.
283 Operation *limitIt = symbol->getParentOp();
284 for (size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) {
285 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
286 scopes.push_back({references[i], &limitIt->getRegion(0)});
287 }
288 return scopes;
289 }
290
291 // Otherwise, we just need the symbol reference for 'symbol' that will be
292 // used within 'limit'. This is the last reference in the list we computed
293 // above if we were able to collect all references.
294 if (!collectedAllReferences) {
295 return {};
296 }
297 return {{references.back(), limit}};
298}
299
300static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Region *limit) {
301 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
302
303 // If we collected some scopes to walk, make sure to constrain the one for
304 // limit to the specific region requested.
305 if (!scopes.empty()) {
306 scopes.back().limit = limit;
307 }
308 return scopes;
309}
310
311static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Region *limit) {
312 return {{SymbolRefAttr::get(symbol), limit}};
313}
314
315static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Operation *limit) {
316 SmallVector<SymbolScope, 1> scopes;
317 auto symbolRef = SymbolRefAttr::get(symbol);
318 for (auto &region : limit->getRegions()) {
319 scopes.push_back({symbolRef, &region});
320 }
321 return scopes;
322}
323
326static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
327 if (ref == subRef) {
328 return true;
329 }
330
331 // If the references are not pointer equal, check to see if `subRef` is a
332 // prefix of `ref`.
333 if (llvm::isa<FlatSymbolRefAttr>(ref) || ref.getRootReference() != subRef.getRootReference()) {
334 return false;
335 }
336
337 auto refLeafs = ref.getNestedReferences();
338 auto subRefLeafs = subRef.getNestedReferences();
339 return subRefLeafs.size() < refLeafs.size() &&
340 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
341}
342
343} // namespace
344
345//===----------------------------------------------------------------------===//
346// llzk::getSymbolUses
347
348namespace {
349
351template <typename FromT>
352static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
353 std::vector<SymbolTable::SymbolUse> uses;
354 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
355 uses.push_back(symbolUse);
356 return WalkResult::advance();
357 };
358 auto result = walkSymbolUses(from, walkFn);
359 return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) : std::nullopt;
360}
361
362} // namespace
363
371std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Operation *from) {
372 return getSymbolUsesImpl(from);
373}
374std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Region *from) {
375 return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
376}
377
378//===----------------------------------------------------------------------===//
// llzk::getSymbolUses (uses of a specific symbol)
380
381namespace {
382
384template <typename SymbolT, typename IRUnitT>
385static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, IRUnitT *limit) {
386 std::vector<SymbolTable::SymbolUse> uses;
387 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
388 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
389 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) {
390 uses.push_back(symbolUse);
391 }
392 })) {
393 return std::nullopt;
394 }
395 }
396 return SymbolTable::UseRange(std::move(uses));
397}
398
399} // namespace
400
405std::optional<SymbolTable::UseRange> llzk::getSymbolUses(StringAttr symbol, Operation *from) {
406 return getSymbolUsesImpl(symbol, from);
407}
408std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Operation *symbol, Operation *from) {
409 return getSymbolUsesImpl(symbol, from);
410}
411std::optional<SymbolTable::UseRange> llzk::getSymbolUses(StringAttr symbol, Region *from) {
412 return getSymbolUsesImpl(symbol, from);
413}
414std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Operation *symbol, Region *from) {
415 return getSymbolUsesImpl(symbol, from);
416}
417
418//===----------------------------------------------------------------------===//
419// llzk::symbolKnownUseEmpty
420
421namespace {
422
424template <typename SymbolT, typename IRUnitT>
425static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
426 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
427 // Walk all of the symbol uses looking for a reference to 'symbol'.
428 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
429 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) ? WalkResult::interrupt()
430 : WalkResult::advance();
431 }) != WalkResult::advance()) {
432 return false;
433 }
434 }
435 return true;
436}
437
438} // namespace
439
444bool llzk::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
445 return symbolKnownUseEmptyImpl(symbol, from);
446}
447bool llzk::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
448 return symbolKnownUseEmptyImpl(symbol, from);
449}
450bool llzk::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
451 return symbolKnownUseEmptyImpl(symbol, from);
452}
453bool llzk::symbolKnownUseEmpty(Operation *symbol, Region *from) {
454 return symbolKnownUseEmptyImpl(symbol, from);
455}
456
457//===----------------------------------------------------------------------===//
458// llzk::getSymbolName
459
460StringAttr llzk::getSymbolName(Operation *op) {
461 // This is modified for LLZK.
462 // `SymbolTable::getSymbolName(Operation*)` asserts if there is no name (ex: in the case of
463 // ModuleOp where the symbol name is optional) and there's no other way to check if the name
464 // exists so this fully involved retrieval method must be used to return `nullptr` if no name.
465 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
466}
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
Definition LICENSE.txt:45
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
bool symbolKnownUseEmpty(mlir::StringAttr symbol, mlir::Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...