16#include <llvm/ADT/StringMap.h>
18#include <clang/Basic/FileManager.h>
19#include <clang/Basic/LangOptions.h>
20#include <clang/Basic/SourceManager.h>
21#include <clang/Lex/Lexer.h>
27llvm::cl::OptionCategory
28 OpGenCat(
"Options for -gen-op-capi-header, -gen-op-capi-impl, and -gen-op-capi-tests");
33 "The dialect name to use for this group of ops. "
34 "Must match across header, implementation, and test generation."
42 "The prefix to use for generated C API function names. "
43 "Default is 'mlir'. Must match across header, implementation, and test generation."
45 llvm::cl::init(
"mlir"), llvm::cl::cat(
OpGenCat)
49 "gen-isa", llvm::cl::desc(
"Generate IsA checks"), llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
53 "gen-op-build", llvm::cl::desc(
"Generate operation build(..) functions"), llvm::cl::init(
true),
58 "gen-operand-getters", llvm::cl::desc(
"Generate operand getters for operations"),
59 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
63 "gen-operand-setters", llvm::cl::desc(
"Generate operand setters for operations"),
64 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
68 "gen-attribute-getters", llvm::cl::desc(
"Generate attribute getters for operations"),
69 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
73 "gen-attribute-setters", llvm::cl::desc(
"Generate attribute setters for operations"),
74 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
78 "gen-region-getters", llvm::cl::desc(
"Generate region getters for operations"),
79 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
83 "gen-result-getters", llvm::cl::desc(
"Generate result getters for operations"),
84 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
88 "gen-type-attr-get", llvm::cl::desc(
"Generate get functions for types and attributes"),
89 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
93 "gen-parameter-getters", llvm::cl::desc(
"Generate parameter getters for types and attributes"),
94 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
98 "gen-extra-class-methods",
99 llvm::cl::desc(
"Generate C API wrappers for methods in `extraClassDeclaration`"),
100 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
116 std::unique_ptr<DiagnosticsEngine>
diags;
127 FileSystemOptions fileSystemOpts;
128 fileMgr =
new FileManager(fileSystemOpts);
135 : impl(std::make_unique<
Impl>()) {
137 llvm::errs() <<
"Warning: ClangLexerContext created with empty source\n";
142 std::unique_ptr<llvm::MemoryBuffer> buffer = llvm::MemoryBuffer::getMemBuffer(
source, bufferName);
144 llvm::errs() <<
"Error: Failed to create memory buffer for ClangLexerContext\n";
148 FileID fileID = impl->sourceMgr->createFileID(std::move(buffer), SrcMgr::C_User);
149 llvm::MemoryBufferRef bufferRef = impl->sourceMgr->getBufferOrFake(fileID);
151 if (bufferRef.getBufferSize() == 0 && !
source.empty()) {
152 llvm::errs() <<
"Error: Failed to get buffer from source manager in ClangLexerContext\n";
157 impl->lexer = std::make_unique<Lexer>(fileID, bufferRef, *impl->sourceMgr, impl->langOpts);
159 impl->lexer->SetCommentRetentionState(
true);
160 lexer = impl->lexer.get();
164 assert(lexer &&
"Lexer not initialized - check isValid() before calling getLexer()");
170 impl && impl->sourceMgr &&
171 "SourceManager not initialized - check isValid() before calling getSourceManager()"
173 return *impl->sourceMgr;
178static inline bool isAccessModifier(StringRef tokenText) {
179 return tokenText ==
"private" || tokenText ==
"public" || tokenText ==
"protected";
187 std::vector<Token> tokens;
189 for (Token tok; !lexer.LexFromRawLexer(tok);) {
190 if (tok.is(tok::eof)) {
193 tokens.push_back(tok);
199static inline std::string getDocumentation(
200 size_t returnTypeStart,
const std::vector<Token> &tokens,
const SourceManager &sourceMgr
203 for (
size_t j = returnTypeStart; j > 0; --j) {
204 Token curr = tokens[j - 1];
205 if (curr.is(tok::comment)) {
206 StringRef comment(sourceMgr.getCharacterData(curr.getLocation()), curr.getLength());
207 comment.consume_front(
"///");
208 comment.consume_front(
"//");
209 if (comment.consume_front(
"/*")) {
210 comment.consume_back(
"*/");
212 comment = comment.trim();
215 if (!comment.empty()) {
222 }
else if (!curr.is(tok::unknown)) {
225 if (curr.isOneOf(tok::semi, tok::r_brace, tok::l_brace)) {
228 if (curr.is(tok::raw_identifier) && isAccessModifier(curr.getRawIdentifier())) {
253updateAccessLevel(
size_t &i,
const std::vector<Token> &tokens,
AccessLevel &currentAccess) {
254 if (i + 1 < tokens.size() && tokens[i + 1].is(tok::colon)) {
255 if (tokens[i].is(tok::raw_identifier)) {
256 StringRef name = tokens[i].getRawIdentifier();
257 if (name ==
"private") {
261 }
else if (name ==
"public") {
265 }
else if (name ==
"protected") {
285static std::string extractReturnType(
286 size_t i,
const std::vector<Token> &tokens,
const SourceManager &sourceMgr,
287 size_t &returnTypeStart
289 std::string returnType;
291 bool isStaticMethod =
false;
294 for (
size_t j = i; j > 0; --j) {
295 Token curr = tokens[j - 1];
297 if (curr.isOneOf(tok::semi, tok::r_brace)) {
303 if (curr.is(tok::raw_identifier)) {
304 StringRef text = curr.getRawIdentifier();
305 if (text ==
"static") {
306 isStaticMethod =
true;
309 if (tokens[j].is(tok::colon) && isAccessModifier(text)) {
311 returnTypeStart = j + 1;
312 assert(returnTypeStart < tokens.size());
319 if (isStaticMethod) {
325 while (tokens[returnTypeStart].is(tok::comment)) {
327 assert(returnTypeStart <= i);
331 returnType.reserve(32);
332 llvm::raw_string_ostream returnTypeStream(returnType);
334 for (
size_t j = returnTypeStart; j < i; ++j) {
336 if (tokens[j].is(tok::comment)) {
340 StringRef tokenText(sourceMgr.getCharacterData(tokens[j].getLocation()), tokens[j].getLength());
343 if (tokens[j].is(tok::raw_identifier) && isAccessModifier(tokenText)) {
345 if (j + 1 < i && tokens[j + 1].is(tok::colon)) {
352 if (tokens[j].is(tok::raw_identifier) && tokenText ==
"return") {
364 if (tokens[j].is(tok::colon)) {
366 if (j == 0 || j + 1 >= i || !tokens[j - 1].is(tok::colon)) {
372 if (!returnType.empty() && !returnType.ends_with(
"::") && tokenText !=
"::" &&
373 !tokenText.starts_with(
"::")) {
374 returnTypeStream <<
' ';
376 returnTypeStream << tokenText;
380 return StringRef(returnType).trim().str();
393static bool parseMethodParameters(
394 size_t i,
const std::vector<Token> &tokens,
const SourceManager &sourceMgr,
395 size_t &closeParenIdx,
bool &hasParameters, std::vector<MethodParameter> &parameters
397 const size_t tokenCount = tokens.size();
398 closeParenIdx = tokenCount;
399 hasParameters =
false;
404 size_t parenDepth = 1;
405 for (
size_t j = i + 2; j < tokenCount; ++j) {
406 if (tokens[j].is(tok::l_paren)) {
408 }
else if (tokens[j].is(tok::r_paren)) {
410 if (parenDepth == 0) {
415 std::vector<Token> paramTokens;
416 for (
size_t k = i + 2; k < j; ++k) {
417 if (k >= tokenCount) {
420 if (!tokens[k].is(tok::comment)) {
421 paramTokens.push_back(tokens[k]);
424 const size_t paramTokenCount = paramTokens.size();
427 if (paramTokenCount == 1) {
428 StringRef paramToken(
429 sourceMgr.getCharacterData(paramTokens[0].getLocation()), paramTokens[0].getLength()
431 if (paramToken !=
"void") {
432 hasParameters =
true;
434 }
else if (paramTokenCount > 1) {
435 hasParameters =
true;
440 std::string currentParamType;
441 std::string currentParamName;
442 bool inDefaultValue =
false;
444 for (
size_t k = 0; k < paramTokenCount; ++k) {
446 if (paramTokens[k].is(tok::comma)) {
448 if (!currentParamType.empty() && !currentParamName.empty()) {
449 parameters.push_back(
MethodParameter(currentParamType, currentParamName));
451 currentParamType.clear();
452 currentParamName.clear();
453 inDefaultValue =
false;
457 if (inDefaultValue) {
461 if (paramTokens[k].is(tok::equal)) {
462 inDefaultValue =
true;
467 sourceMgr.getCharacterData(paramTokens[k].getLocation()), paramTokens[k].getLength()
472 if (paramTokens[k].is(tok::raw_identifier)) {
473 if (k + 1 == paramTokenCount ||
474 (k + 1 < paramTokenCount && paramTokens[k + 1].isOneOf(tok::comma, tok::equal))) {
475 currentParamName = tokenText.str();
481 llvm::raw_string_ostream paramTypeStream(currentParamType);
482 if (!currentParamType.empty() && tokenText !=
"*" && tokenText !=
"&" &&
483 tokenText !=
"::" && !tokenText.starts_with(
"::") &&
484 !StringRef(currentParamType).ends_with(
"::")) {
485 paramTypeStream <<
' ';
487 paramTypeStream << tokenText;
488 currentParamType = paramTypeStream.str();
492 if (!currentParamType.empty() && !currentParamName.empty()) {
493 parameters.push_back(
MethodParameter(currentParamType, currentParamName));
514checkConstAndFindEnd(
size_t closeParenIdx,
const std::vector<Token> &tokens,
size_t &endIdx) {
515 bool isConst =
false;
516 endIdx = closeParenIdx + 1;
518 while (endIdx < tokens.size()) {
519 Token curr = tokens[endIdx];
520 if (curr.isOneOf(tok::semi, tok::l_brace)) {
523 if (curr.is(tok::raw_identifier) && curr.getRawIdentifier() ==
"const") {
551 if (extraDecl.empty()) {
558 llvm::errs() <<
"Error: Failed to create lexer context for parseExtraMethods\n";
563 llvm::StringMap<std::optional<ExtraMethod>> methods;
566 const std::vector<Token> tokens = tokenize(lexerCtx);
567 const size_t tokenCount = tokens.size();
574 for (
size_t i = 0; i < tokenCount; ++i) {
576 if (tokens[i].is(tok::comment)) {
581 if (updateAccessLevel(i, tokens, currentAccess)) {
592 if (i + 1 < tokenCount && tokens[i + 1].is(tok::l_paren) && tokens[i].is(tok::raw_identifier)) {
593 StringRef methodName = tokens[i].getRawIdentifier();
601 size_t returnTypeStart = 0;
602 std::string returnType = extractReturnType(i, tokens, sourceMgr, returnTypeStart);
605 if (returnType.empty()) {
610 size_t closeParenIdx = tokenCount;
611 bool hasParameters =
false;
612 std::vector<MethodParameter> parameters;
613 if (!parseMethodParameters(i, tokens, sourceMgr, closeParenIdx, hasParameters, parameters)) {
620 bool isConst = checkConstAndFindEnd(closeParenIdx, tokens, endIdx);
623 if (!returnType.empty() && !methodName.empty()) {
624 if (methods.contains(methodName)) {
625 warnSkipped(methodName,
"C API does not support method overloading");
626 methods[methodName] = std::nullopt;
631 method.
documentation = getDocumentation(returnTypeStart, tokens, sourceMgr);
635 methods[methodName] = std::make_optional(method);
645 return llvm::to_vector(
647 llvm::make_filter_range(methods, [](
const auto &p) {
return p.second.has_value(); }),
648 [](
const auto &p) {
return p.second.value(); }
655 if (cppType == typeName) {
660 StringRef prefix = cppType;
661 prefix.consume_front(
"::");
662 if (prefix.consume_front(
"mlir::")) {
663 return prefix == typeName;
671 cppType = cppType.trim();
675 return std::make_optional(cppType.str());
680 return std::make_optional(
"int64_t");
685 if (cppType.ends_with(
" *") || cppType.ends_with(
"*")) {
686 size_t starPos = cppType.rfind(
'*');
687 if (starPos != StringRef::npos) {
688 StringRef baseType = cppType.substr(0, starPos).trim();
690 return std::make_optional(
"MlirAsmState");
693 return std::make_optional(
"MlirBytecodeWriterConfig");
696 return std::make_optional(
"MlirContext");
699 return std::make_optional(
"MlirDialect");
702 return std::make_optional(
"MlirDialectRegistry");
705 return std::make_optional(
"MlirOperation");
708 return std::make_optional(
"MlirBlock");
711 return std::make_optional(
"MlirOpOperand");
714 return std::make_optional(
"MlirOpPrintingFlags");
717 return std::make_optional(
"MlirRegion");
720 return std::make_optional(
"MlirSymbolTable");
723 llvm::errs() <<
"Error: Failed to parse pointer type: " << cppType <<
'\n';
730 return std::make_optional(
"MlirAttribute");
733 return std::make_optional(
"MlirIdentifier");
736 return std::make_optional(
"MlirLocation");
739 return std::make_optional(
"MlirModule");
742 return std::make_optional(
"MlirType");
745 return std::make_optional(
"MlirValue");
749 return std::make_optional(
"MlirAffineExpr");
753 return std::make_optional(
"MlirAffineMap");
757 return std::make_optional(
"MlirIntegerSet");
761 return std::make_optional(
"MlirTypeID");
766 return std::make_optional(
"MlirStringRef");
769 return std::make_optional(
"MlirLogicalResult");
773 if (cppType.ends_with(
"Type")) {
774 return std::make_optional(
"MlirType");
776 if (cppType.ends_with(
"Attr")) {
777 return std::make_optional(
"MlirAttribute");
779 if (cppType.ends_with(
"Op")) {
780 return std::make_optional(
"MlirOperation");
789 assert(!
isArrayRefType(cppType) &&
"must check `isArrayRefType()` outside");
792 if (capiTypeOpt.has_value()) {
793 return capiTypeOpt.value();
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
std::string mapCppTypeToCapiType(StringRef cppType)
AccessLevel
Access level tracking for C++ class declarations.
SmallVector< ExtraMethod > parseExtraMethods(StringRef extraDecl)
Parse method declarations from extraClassDeclaration using Clang's Lexer.
bool matchesMLIRClass(StringRef cppType, StringRef typeName)
Check if a C++ type matches an MLIR type pattern.
llvm::cl::OptionCategory OpGenCat
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
bool isCppModifierKeyword(mlir::StringRef tokenText)
Check if a token text represents a C++ modifier/specifier keyword.
bool isAPIntType(mlir::StringRef cppType)
Check if a C++ type is APInt.
llvm::cl::opt< std::string > FunctionPrefix
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
void warnSkipped(const S &methodName, const std::string &message)
Print warning about skipping a function.
bool isCppLanguageConstruct(mlir::StringRef methodName)
Check if a method name represents a C++ control flow keyword or language construct.
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation source
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated documentation
RAII wrapper for Clang lexer infrastructure.
clang::SourceManager & getSourceManager() const
Get the source manager instance.
bool isValid() const
Check if the lexer was successfully created.
ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName="input")
Construct a lexer context for the given source code.
clang::Lexer & getLexer() const
Get the lexer instance.
IntrusiveRefCntPtr< DiagnosticOptions > diagOpts
Diagnostic options for configuring diagnostics.
IntrusiveRefCntPtr< FileManager > fileMgr
File manager for handling virtual files.
IntrusiveRefCntPtr< DiagnosticIDs > diagIDs
Diagnostic IDs for error reporting.
LangOptions langOpts
C++ language options for lexer configuration.
std::unique_ptr< SourceManager > sourceMgr
Source manager for tracking file locations.
std::unique_ptr< Lexer > lexer
The actual lexer instance.
std::unique_ptr< DiagnosticsEngine > diags
Diagnostics engine for handling errors and warnings.
Structure to represent a parameter in a parsed method signature from an extraClassDeclaration