LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SymbolLookup.h
Go to the documentation of this file.
1//===-- SymbolLookup.h - Symbol Lookup Functions ----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
16#pragma once
17
18#include "llzk/Util/Constants.h"
19
20#include <mlir/IR/BuiltinOps.h>
21#include <mlir/IR/Operation.h>
22#include <mlir/IR/OwningOpRef.h>
23
24#include <llvm/ADT/StringRef.h>
25
26#include <variant>
27#include <vector>
28
29namespace llzk {
30
32 std::shared_ptr<std::pair<mlir::OwningOpRef<mlir::ModuleOp>, mlir::SymbolTableCollection>>;
33
35public:
36 SymbolLookupResultUntyped() : op(nullptr) {}
37 SymbolLookupResultUntyped(mlir::Operation *opPtr) : op(opPtr) {}
38
40 mlir::Operation *operator->();
41 mlir::Operation &operator*();
42 mlir::Operation &operator*() const;
43 mlir::Operation *get();
44 mlir::Operation *get() const;
45
47 operator bool() const;
48
50 std::vector<llvm::StringRef> getIncludeSymNames() const { return includeSymNameStack; }
51
53 bool viaInclude() const { return !includeSymNameStack.empty(); }
54
55 mlir::SymbolTableCollection *getSymbolTableCache() {
56 if (managedResources) {
57 return &managedResources->second;
58 } else {
59 return nullptr;
60 }
61 }
62
64 void manage(mlir::OwningOpRef<mlir::ModuleOp> &&ptr, mlir::SymbolTableCollection &&tables);
65
67 void trackIncludeAsName(llvm::StringRef includeOpSymName);
68
69 bool operator==(const SymbolLookupResultUntyped &rhs) const { return op == rhs.op; }
70
71private:
72 mlir::Operation *op;
75 ManagedResources managedResources;
77 std::vector<llvm::StringRef> includeSymNameStack;
78
79 friend class Within;
80};
81
82template <typename T> class SymbolLookupResult {
83public:
84 SymbolLookupResult(SymbolLookupResultUntyped &&innerRes) : inner(std::move(innerRes)) {}
85
88 T operator->() { return llvm::dyn_cast<T>(*inner); }
89 T operator*() { return llvm::dyn_cast<T>(*inner); }
90 const T operator*() const { return llvm::dyn_cast<T>(*inner); }
91 T get() { return llvm::dyn_cast<T>(inner.get()); }
92 T get() const { return llvm::dyn_cast<T>(inner.get()); }
93
94 operator bool() const { return inner && llvm::isa<T>(*inner); }
95
97 std::vector<llvm::StringRef> getIncludeSymNames() const { return inner.getIncludeSymNames(); }
98
100 bool viaInclude() const { return inner.viaInclude(); }
101
102 bool operator==(const SymbolLookupResult<T> &rhs) const { return inner == rhs.inner; }
103
104private:
106
107 friend class Within;
108};
109
110class Within {
111public:
113 Within() : from(nullptr) {}
115 Within(mlir::Operation *op) : from(op) { assert(op && "cannot lookup within nullptr"); }
117 Within(SymbolLookupResultUntyped &&res) : from(std::move(res)) {}
119 template <typename T> Within(SymbolLookupResult<T> &&res) : Within(std::move(res.inner)) {}
120
121 Within(const Within &) = delete;
122 Within(Within &&other) noexcept : from(std::move(other.from)) {}
123 Within &operator=(const Within &) = delete;
124 Within &operator=(Within &&) noexcept;
125
126 inline static Within root() { return Within(); }
127
128 mlir::FailureOr<SymbolLookupResultUntyped> lookup(
129 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
130 bool reportMissing = true
131 ) &&;
132
133private:
134 std::variant<mlir::Operation *, SymbolLookupResultUntyped> from;
135};
136
137inline mlir::FailureOr<SymbolLookupResultUntyped> lookupSymbolIn(
138 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin,
139 mlir::Operation *origin, bool reportMissing = true
140) {
141 return std::move(lookupWithin).lookup(tables, symbol, origin, reportMissing);
142}
143
144inline mlir::FailureOr<SymbolLookupResultUntyped> lookupTopLevelSymbol(
145 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
146 bool reportMissing = true
147) {
148 return Within().lookup(tables, symbol, origin, reportMissing);
149}
150
151template <typename T>
152inline mlir::FailureOr<SymbolLookupResult<T>> lookupSymbolIn(
153 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin,
154 mlir::Operation *origin, bool reportMissing = true
155) {
156 auto found = lookupSymbolIn(tables, symbol, std::move(lookupWithin), origin, reportMissing);
157 if (mlir::failed(found)) {
158 return mlir::failure(); // lookupSymbolIn() already emits a sufficient error message
159 }
160 // Keep a copy of the op ptr in case we need it for displaying diagnostics
161 mlir::Operation *op = found->get();
162 // ... since the untyped result gets moved here into a typed result.
163 SymbolLookupResult<T> ret(std::move(*found));
164 if (!ret) {
165 if (reportMissing) {
166 return origin->emitError() << "symbol \"" << symbol << "\" references a '" << op->getName()
167 << "' but expected a '" << T::getOperationName() << '\'';
168 } else {
169 return mlir::failure();
170 }
171 }
172 return ret;
173}
174
175template <typename T>
176inline mlir::FailureOr<SymbolLookupResult<T>> lookupTopLevelSymbol(
177 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
178 bool reportMissing = true
179) {
180 return lookupSymbolIn<T>(tables, symbol, Within(), origin, reportMissing);
181}
182
183} // namespace llzk
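Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.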
Definition LICENSE.txt:45
mlir::SymbolTableCollection * getSymbolTableCache()
bool viaInclude() const
Return 'true' if at least one IncludeOp was traversed to load this result.
void manage(mlir::OwningOpRef< mlir::ModuleOp > &&ptr, mlir::SymbolTableCollection &&tables)
Adds a pointer to the set of resources the result has to manage the lifetime of.
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
void trackIncludeAsName(llvm::StringRef includeOpSymName)
Adds the symbol name from the IncludeOp that caused the module to be loaded.
bool operator==(const SymbolLookupResultUntyped &rhs) const
SymbolLookupResultUntyped(mlir::Operation *opPtr)
mlir::Operation * operator->()
Access the internal operation.
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
const T operator*() const
SymbolLookupResult(SymbolLookupResultUntyped &&innerRes)
bool viaInclude() const
Return 'true' if at least one IncludeOp was traversed to load this result.
bool operator==(const SymbolLookupResult< T > &rhs) const
T operator->()
Access the internal operation as type T.
mlir::FailureOr< SymbolLookupResultUntyped > lookup(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true) &&
Within(const Within &)=delete
Within()
Lookup within the top-level (root) module.
Within(SymbolLookupResultUntyped &&res)
Lookup within the Operation of the given result and transfer managed resources.
Within(Within &&other) noexcept
Within & operator=(const Within &)=delete
Within(mlir::Operation *op)
Lookup within the given Operation (cannot be nullptr)
static Within root()
Within(SymbolLookupResult< T > &&res)
Lookup within the Operation of the given result and transfer managed resources.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::shared_ptr< std::pair< mlir::OwningOpRef< mlir::ModuleOp >, mlir::SymbolTableCollection > > ManagedResources
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)