38#include <llvm/ADT/SmallPtrSet.h>
50static bool isPotentiallyUnknownSymbolTable(Operation *op) {
51 return op->getNumRegions() == 1 && !op->getDialect();
55static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
56 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
63static LogicalResult collectValidReferencesFor(
64 Operation *symbol, StringAttr symbolName, Operation *within,
65 SmallVectorImpl<SymbolRefAttr> &results
67 assert(within->isAncestor(symbol) &&
"expected 'within' to be an ancestor");
68 MLIRContext *ctx = symbol->getContext();
70 auto leafRef = FlatSymbolRefAttr::get(symbolName);
71 results.push_back(leafRef);
74 Operation *symbolTableOp = symbol->getParentOp();
75 if (within == symbolTableOp) {
80 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
81 StringAttr symbolNameId = StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
84 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
88 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
89 if (!symbolTableName) {
92 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
94 symbolTableOp = symbolTableOp->getParentOp();
95 if (symbolTableOp == within) {
98 nestedRefs.insert(nestedRefs.begin(), FlatSymbolRefAttr::get(symbolTableName));
106static std::optional<WalkResult> walkSymbolTable(
107 MutableArrayRef<Region> regions, function_ref<std::optional<WalkResult>(Operation *)> callback
109 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
110 while (!worklist.empty()) {
111 for (Operation &op : worklist.pop_back_val()->getOps()) {
112 std::optional<WalkResult> result = callback(&op);
113 if (result != WalkResult::advance()) {
119 if (!op.hasTrait<OpTrait::SymbolTable>()) {
120 for (Region ®ion : op.getRegions()) {
121 worklist.push_back(®ion);
126 return WalkResult::advance();
132walkSymbolRefs(Operation *op, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
134 auto walkFn = [&op, &callback](SymbolRefAttr symbolRef) {
135 if (callback({op, symbolRef}).wasInterrupted()) {
136 return WalkResult::interrupt();
138 return WalkResult::skip();
140 for (Type t : op->getOperandTypes()) {
141 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
142 return WalkResult::interrupt();
145 for (Type t : op->getResultTypes()) {
146 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
147 return WalkResult::interrupt();
150 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(walkFn);
156static std::optional<WalkResult> walkSymbolUses(
157 MutableArrayRef<Region> regions, function_ref<WalkResult(SymbolTable::SymbolUse)> callback
159 return walkSymbolTable(regions, [&](Operation *op) -> std::optional<WalkResult> {
161 if (isPotentiallyUnknownSymbolTable(op)) {
164 return walkSymbolRefs(op, callback);
171static std::optional<WalkResult>
172walkSymbolUses(Operation *
from, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
177 if (isPotentiallyUnknownSymbolTable(
from)) {
182 if (walkSymbolRefs(
from, callback).wasInterrupted()) {
183 return WalkResult::interrupt();
189 if (!
from->hasTrait<OpTrait::SymbolTable>()) {
190 return walkSymbolUses(
from->getRegions(), callback);
192 return WalkResult::advance();
206 std::enable_if_t<!std::is_same<
207 typename llvm::function_traits<CallbackT>::result_t,
void>::value> * =
nullptr>
208 std::optional<WalkResult> walk(CallbackT cback) {
209 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
210 return walkSymbolUses(*region, cback);
212 return walkSymbolUses(limit.get<Operation *>(), cback);
218 std::enable_if_t<std::is_same<
219 typename llvm::function_traits<CallbackT>::result_t,
void>::value> * =
nullptr>
220 std::optional<WalkResult> walk(CallbackT cback) {
221 return walk([=](SymbolTable::SymbolUse
use) {
return cback(
use), WalkResult::advance(); });
226 template <
typename CallbackT> std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
227 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
228 return ::walkSymbolTable(*region, cback);
230 return ::walkSymbolTable(limit.get<Operation *>(), cback);
234 SymbolRefAttr symbol;
237 llvm::PointerUnion<Operation *, Region *> limit;
241static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Operation *limit) {
242 StringAttr symName = SymbolTable::getSymbolName(symbol);
243 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
246 SetVector<Operation *, SmallVector<Operation *, 4>, SmallPtrSet<Operation *, 4>> limitAncestors;
247 Operation *limitAncestor = limit;
250 if (limitAncestor == symbol) {
253 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) {
254 return {{SymbolRefAttr::get(symName), limit}};
259 limitAncestors.insert(limitAncestor);
260 }
while ((limitAncestor = limitAncestor->getParentOp()));
263 Operation *commonAncestor = symbol->getParentOp();
265 if (limitAncestors.count(commonAncestor)) {
268 }
while ((commonAncestor = commonAncestor->getParentOp()));
269 assert(commonAncestor &&
"'limit' and 'symbol' have no common ancestor");
273 SmallVector<SymbolRefAttr, 2> references;
274 bool collectedAllReferences =
275 succeeded(collectValidReferencesFor(symbol, symName, commonAncestor, references));
278 if (commonAncestor == limit) {
279 SmallVector<SymbolScope, 2> scopes;
283 Operation *limitIt = symbol->getParentOp();
284 for (
size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) {
285 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
286 scopes.push_back({references[i], &limitIt->getRegion(0)});
294 if (!collectedAllReferences) {
297 return {{references.back(), limit}};
300static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Region *limit) {
301 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
305 if (!scopes.empty()) {
306 scopes.back().limit = limit;
311static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Region *limit) {
312 return {{SymbolRefAttr::get(symbol), limit}};
315static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Operation *limit) {
316 SmallVector<SymbolScope, 1> scopes;
317 auto symbolRef = SymbolRefAttr::get(symbol);
318 for (
auto ®ion : limit->getRegions()) {
319 scopes.push_back({symbolRef, ®ion});
326static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
333 if (llvm::isa<FlatSymbolRefAttr>(ref) || ref.getRootReference() != subRef.getRootReference()) {
337 auto refLeafs = ref.getNestedReferences();
338 auto subRefLeafs = subRef.getNestedReferences();
339 return subRefLeafs.size() < refLeafs.size() &&
340 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
351template <
typename FromT>
352static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT
from) {
353 std::vector<SymbolTable::SymbolUse> uses;
354 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
355 uses.push_back(symbolUse);
356 return WalkResult::advance();
358 auto result = walkSymbolUses(
from, walkFn);
359 return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) : std::nullopt;
372 return getSymbolUsesImpl(
from);
375 return getSymbolUsesImpl(MutableArrayRef<Region>(*
from));
384template <
typename SymbolT,
typename IRUnitT>
385static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, IRUnitT *limit) {
386 std::vector<SymbolTable::SymbolUse> uses;
387 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
388 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
389 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) {
390 uses.push_back(symbolUse);
396 return SymbolTable::UseRange(std::move(uses));
406 return getSymbolUsesImpl(symbol,
from);
409 return getSymbolUsesImpl(symbol,
from);
412 return getSymbolUsesImpl(symbol,
from);
415 return getSymbolUsesImpl(symbol,
from);
424template <
typename SymbolT,
typename IRUnitT>
425static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
426 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
428 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
429 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) ? WalkResult::interrupt()
430 : WalkResult::advance();
431 }) != WalkResult::advance()) {
445 return symbolKnownUseEmptyImpl(symbol,
from);
448 return symbolKnownUseEmptyImpl(symbol,
from);
451 return symbolKnownUseEmptyImpl(symbol,
from);
454 return symbolKnownUseEmptyImpl(symbol,
from);
465 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
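Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.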
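Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.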
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
bool symbolKnownUseEmpty(mlir::StringAttr symbol, mlir::Operation *from)
Return true if the given symbol is known to have no uses that are nested within the given operation 'from'...