LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
CommonCAPIGen.cpp
Go to the documentation of this file.
1//===- CommonCAPIGen.cpp - Common utilities for C API generation ----------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Shared command-line options for all CAPI generators (ops, attrs, types)
11//
12//===----------------------------------------------------------------------===//
13
14#include "CommonCAPIGen.h"
15
16#include <llvm/ADT/StringMap.h>
17
18#include <clang/Basic/FileManager.h>
19#include <clang/Basic/LangOptions.h>
20#include <clang/Basic/SourceManager.h>
21#include <clang/Lex/Lexer.h>
22#include <optional>
23
24using namespace mlir;
25using namespace clang;
26
/// Category that groups every flag below (each one is registered with `llvm::cl::cat(OpGenCat)`).
/// NOTE(review): the title names only the op generators, while this file describes these options
/// as shared by all CAPI generators (ops, attrs, types) and also registers type/attribute flags
/// (e.g. `-gen-type-attr-get`) here — confirm whether the title should be broadened.
llvm::cl::OptionCategory
    OpGenCat("Options for -gen-op-capi-header, -gen-op-capi-impl, and -gen-op-capi-tests");

/// `-dialect`: dialect name for this group of generated declarations. Must be identical across
/// the header, implementation, and test generators.
llvm::cl::opt<std::string> DialectName(
    "dialect",
    llvm::cl::desc(
        "The dialect name to use for this group of ops. "
        "Must match across header, implementation, and test generation."
    ),
    llvm::cl::cat(OpGenCat)
);

/// `-prefix`: prefix for generated C API function names (defaults to "mlir"). Must be identical
/// across the header, implementation, and test generators.
llvm::cl::opt<std::string> FunctionPrefix(
    "prefix",
    llvm::cl::desc(
        "The prefix to use for generated C API function names. "
        "Default is 'mlir'. Must match across header, implementation, and test generation."
    ),
    llvm::cl::init("mlir"), llvm::cl::cat(OpGenCat)
);

// The remaining flags are feature toggles. Every one of them defaults to enabled
// (`llvm::cl::init(true)`).

/// `-gen-isa`: emit IsA checks.
llvm::cl::opt<bool> GenIsA(
    "gen-isa", llvm::cl::desc("Generate IsA checks"), llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-op-build`: emit operation build(..) functions.
llvm::cl::opt<bool> GenOpBuild(
    "gen-op-build", llvm::cl::desc("Generate operation build(..) functions"), llvm::cl::init(true),
    llvm::cl::cat(OpGenCat)
);

/// `-gen-operand-getters`: emit operand getters for operations.
llvm::cl::opt<bool> GenOpOperandGetters(
    "gen-operand-getters", llvm::cl::desc("Generate operand getters for operations"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-operand-setters`: emit operand setters for operations.
llvm::cl::opt<bool> GenOpOperandSetters(
    "gen-operand-setters", llvm::cl::desc("Generate operand setters for operations"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-attribute-getters`: emit attribute getters for operations.
llvm::cl::opt<bool> GenOpAttributeGetters(
    "gen-attribute-getters", llvm::cl::desc("Generate attribute getters for operations"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-attribute-setters`: emit attribute setters for operations.
llvm::cl::opt<bool> GenOpAttributeSetters(
    "gen-attribute-setters", llvm::cl::desc("Generate attribute setters for operations"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-region-getters`: emit region getters for operations.
llvm::cl::opt<bool> GenOpRegionGetters(
    "gen-region-getters", llvm::cl::desc("Generate region getters for operations"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-result-getters`: emit result getters for operations.
llvm::cl::opt<bool> GenOpResultGetters(
    "gen-result-getters", llvm::cl::desc("Generate result getters for operations"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-type-attr-get`: emit get functions for types and attributes.
llvm::cl::opt<bool> GenTypeOrAttrGet(
    "gen-type-attr-get", llvm::cl::desc("Generate get functions for types and attributes"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-parameter-getters`: emit parameter getters for types and attributes.
llvm::cl::opt<bool> GenTypeOrAttrParamGetters(
    "gen-parameter-getters", llvm::cl::desc("Generate parameter getters for types and attributes"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);

/// `-gen-extra-class-methods`: emit C API wrappers for methods declared in
/// `extraClassDeclaration`.
llvm::cl::opt<bool> GenExtraClassMethods(
    "gen-extra-class-methods",
    llvm::cl::desc("Generate C API wrappers for methods in `extraClassDeclaration`"),
    llvm::cl::init(true), llvm::cl::cat(OpGenCat)
);
102
103//===----------------------------------------------------------------------===//
104// ClangLexerContext Implementation
105//===----------------------------------------------------------------------===//
  /// C++ language options for lexer configuration.
  LangOptions langOpts;
  /// File manager for handling virtual files.
  IntrusiveRefCntPtr<FileManager> fileMgr;
  /// Diagnostic IDs for error reporting.
  IntrusiveRefCntPtr<DiagnosticIDs> diagIDs;
  /// Diagnostic options for configuring diagnostics.
  IntrusiveRefCntPtr<DiagnosticOptions> diagOpts;
  /// Diagnostics engine for handling errors and warnings.
  std::unique_ptr<DiagnosticsEngine> diags;
  /// Source manager for tracking file locations. It is constructed from `*diags` and `*fileMgr`
  /// and is declared after them, so it is destroyed before them (members are destroyed in
  /// reverse declaration order).
  std::unique_ptr<SourceManager> sourceMgr;
  /// The actual lexer instance. It stays null here; ClangLexerContext's constructor creates it
  /// once an input buffer has been registered with `sourceMgr`.
  std::unique_ptr<Lexer> lexer;

  /// Creates the Clang support objects needed to lex a standalone in-memory buffer.
  Impl() : diagIDs(new DiagnosticIDs()), diagOpts(new DiagnosticOptions()) {
    // Enable C++ language features for lexing
    langOpts.CPlusPlus = true;
    langOpts.CPlusPlus11 = true;

    FileSystemOptions fileSystemOpts;
    fileMgr = new FileManager(fileSystemOpts);
    diags = std::make_unique<DiagnosticsEngine>(diagIDs, diagOpts);
    sourceMgr = std::make_unique<SourceManager>(*diags, *fileMgr);
  }
};
133
134ClangLexerContext::ClangLexerContext(StringRef source, StringRef bufferName)
135 : impl(std::make_unique<Impl>()) {
136 if (source.empty()) {
137 llvm::errs() << "Warning: ClangLexerContext created with empty source\n";
138 return;
139 }
140
141 // Create a memory buffer for the input
142 std::unique_ptr<llvm::MemoryBuffer> buffer = llvm::MemoryBuffer::getMemBuffer(source, bufferName);
143 if (!buffer) {
144 llvm::errs() << "Error: Failed to create memory buffer for ClangLexerContext\n";
145 return;
146 }
147
148 FileID fileID = impl->sourceMgr->createFileID(std::move(buffer), SrcMgr::C_User);
149 llvm::MemoryBufferRef bufferRef = impl->sourceMgr->getBufferOrFake(fileID);
150
151 if (bufferRef.getBufferSize() == 0 && !source.empty()) {
152 llvm::errs() << "Error: Failed to get buffer from source manager in ClangLexerContext\n";
153 return;
154 }
155
156 // Create the lexer
157 impl->lexer = std::make_unique<Lexer>(fileID, bufferRef, *impl->sourceMgr, impl->langOpts);
158 // Enable comment parsing for extraClassDeclaration method extraction
159 impl->lexer->SetCommentRetentionState(true);
160 lexer = impl->lexer.get();
161}
162
  // `lexer` is only assigned at the very end of a successful construction, so an invalid
  // context trips this assert; callers are expected to check isValid() first.
  assert(lexer && "Lexer not initialized - check isValid() before calling getLexer()");
  return *lexer;
}
167
  // `impl` and its SourceManager are created in the constructor's initializer, before any early
  // return, so this is expected to hold even when isValid() is false.
  assert(
      impl && impl->sourceMgr &&
      "SourceManager not initialized - check isValid() before calling getSourceManager()"
  );
  return *impl->sourceMgr;
}
175
176namespace {
177
178static inline bool isAccessModifier(StringRef tokenText) {
179 return tokenText == "private" || tokenText == "public" || tokenText == "protected";
180}
181
185static inline std::vector<Token> tokenize(const ClangLexerContext &lexerCtx) {
186 Lexer &lexer = lexerCtx.getLexer();
187 std::vector<Token> tokens;
188 tokens.reserve(128); // Reasonable default
189 for (Token tok; !lexer.LexFromRawLexer(tok);) {
190 if (tok.is(tok::eof)) {
191 break;
192 }
193 tokens.push_back(tok);
194 }
195 return tokens;
196}
197
199static inline std::string getDocumentation(
200 size_t returnTypeStart, const std::vector<Token> &tokens, const SourceManager &sourceMgr
201) {
202 std::string documentation;
203 for (size_t j = returnTypeStart; j > 0; --j) {
204 Token curr = tokens[j - 1];
205 if (curr.is(tok::comment)) {
206 StringRef comment(sourceMgr.getCharacterData(curr.getLocation()), curr.getLength());
207 comment.consume_front("///");
208 comment.consume_front("//");
209 if (comment.consume_front("/*")) {
210 comment.consume_back("*/");
211 }
212 comment = comment.trim();
213
214 // Trim whitespace
215 if (!comment.empty()) {
216 if (!documentation.empty()) {
217 documentation = comment.str() + " " + documentation;
218 } else {
219 documentation = comment.str();
220 }
221 }
222 } else if (!curr.is(tok::unknown)) {
223 // Stop looking backwards when we hit a non-comment, non-whitespace token
224 // that could be part of another declaration
225 if (curr.isOneOf(tok::semi, tok::r_brace, tok::l_brace)) {
226 break;
227 }
228 if (curr.is(tok::raw_identifier) && isAccessModifier(curr.getRawIdentifier())) {
229 break;
230 }
231 }
232 }
233 return documentation;
234}
235
236} // namespace
237
238//===----------------------------------------------------------------------===//
239// Method Parsing Implementation
240//===----------------------------------------------------------------------===//
241
244
252static bool
253updateAccessLevel(size_t &i, const std::vector<Token> &tokens, AccessLevel &currentAccess) {
254 if (i + 1 < tokens.size() && tokens[i + 1].is(tok::colon)) {
255 if (tokens[i].is(tok::raw_identifier)) {
256 StringRef name = tokens[i].getRawIdentifier();
257 if (name == "private") {
258 currentAccess = AccessLevel::Private;
259 i++; // extra skip for the colon
260 return true;
261 } else if (name == "public") {
262 currentAccess = AccessLevel::Public;
263 i++; // extra skip for the colon
264 return true;
265 } else if (name == "protected") {
266 currentAccess = AccessLevel::Protected;
267 i++; // extra skip for the colon
268 return true;
269 }
270 }
271 }
272 return false;
273}
274
/// Extracts the textual return type of the method whose name token is `tokens[i]`.
///
/// Walks backwards from the name to the nearest declaration boundary: a `;` or `}`, or an access
/// specifier immediately followed by `:`. If a `static` keyword is met first, the method is
/// treated as static and an empty string is returned. Otherwise the tokens between the boundary
/// and the name are joined. Comments, access specifiers, standalone `:` tokens and anything
/// accepted by `isCppModifierKeyword()` are skipped. Reaching a `return` keyword (implementation
/// code rather than a declaration) also yields an empty string.
///
/// @param i index of the method-name token.
/// @param tokens the raw token stream.
/// @param sourceMgr used to read each token's spelling.
/// @param returnTypeStart [out] index of the first non-comment token of the declaration (left
///        at 0 for static methods).
/// @return the whitespace-trimmed return type, or an empty string in the cases above.
static std::string extractReturnType(
    size_t i, const std::vector<Token> &tokens, const SourceManager &sourceMgr,
    size_t &returnTypeStart
) {
  std::string returnType;
  returnTypeStart = 0;
  bool isStaticMethod = false;

  // Look backwards for return type start, stopping at declaration boundaries
  for (size_t j = i; j > 0; --j) {
    Token curr = tokens[j - 1];
    // Semicolon or right brace indicates lookback has reached the end of a prior declaration.
    // `{` is deliberately not a boundary here, so a lookback that starts inside an inline body
    // continues into that method's signature.
    if (curr.isOneOf(tok::semi, tok::r_brace)) {
      returnTypeStart = j;
      break;
    }
    // Check for "static" or access modifiers in the return type lookback (both appear as
    // `raw_identifier` in raw token stream).
    if (curr.is(tok::raw_identifier)) {
      StringRef text = curr.getRawIdentifier();
      if (text == "static") {
        isStaticMethod = true;
        break;
      }
      // `tokens[j]` is the token right after `curr`; an access modifier only counts as a
      // boundary when it forms a label like `public:`.
      if (tokens[j].is(tok::colon) && isAccessModifier(text)) {
        // In this case, `returnTypeStart` must be after the colon.
        returnTypeStart = j + 1;
        assert(returnTypeStart < tokens.size());
        break;
      }
    }
  }

  // Skip static methods (for now)
  if (isStaticMethod) {
    return "";
  }

  // Adjust `returnTypeStart` for potential comment tokens. Skip as many
  // sequential comments as needed. The method name at `tokens[i]` is not a comment, so this
  // stops at or before `i`.
  while (tokens[returnTypeStart].is(tok::comment)) {
    returnTypeStart++;
    assert(returnTypeStart <= i);
  }

  // Build return type from tokens, skipping modifiers and comments.
  // NOTE(review): the loop below reads `returnType` directly while writing through
  // `returnTypeStream`. That relies on `raw_string_ostream` being unbuffered; confirm for the
  // LLVM version in use.
  returnType.reserve(32); // Reasonable default for most type names
  llvm::raw_string_ostream returnTypeStream(returnType);

  for (size_t j = returnTypeStart; j < i; ++j) {
    // Skip comments - they should be extracted as documentation, not part of the return type
    if (tokens[j].is(tok::comment)) {
      continue;
    }

    StringRef tokenText(sourceMgr.getCharacterData(tokens[j].getLocation()), tokens[j].getLength());

    // Skip access specifiers (e.g., "private", "public", "protected").
    if (tokens[j].is(tok::raw_identifier) && isAccessModifier(tokenText)) {
      // If followed by a colon, skip that too
      if (j + 1 < i && tokens[j + 1].is(tok::colon)) {
        j++; // Skip the colon too
      }
      continue;
    }

    // Skip common implementation keywords that indicate we're in code, not a declaration.
    if (tokens[j].is(tok::raw_identifier) && tokenText == "return") {
      // This indicates we've hit implementation code, stop parsing
      returnType.clear();
      break;
    }

    // Skip modifiers and language keywords that shouldn't be in the return type
    if (tokens[j].is(tok::raw_identifier) && isCppModifierKeyword(tokenText)) {
      continue;
    }

    // Skip standalone colons (from lookback to access specifiers)
    // NOTE(review): with `langOpts.CPlusPlus` set, `::` is presumably lexed as a single
    // `coloncolon` token, so the "part of ::" case may never trigger — confirm.
    if (tokens[j].is(tok::colon)) {
      // Only skip if it's not part of ::
      if (j == 0 || j + 1 >= i || !tokens[j - 1].is(tok::colon)) {
        continue;
      }
    }

    // Add spacing between tokens (but not around ::)
    if (!returnType.empty() && !returnType.ends_with("::") && tokenText != "::" &&
        !tokenText.starts_with("::")) {
      returnTypeStream << ' ';
    }
    returnTypeStream << tokenText;
  }

  // Trim possible whitespace
  return StringRef(returnType).trim().str();
}
382
393static bool parseMethodParameters(
394 size_t i, const std::vector<Token> &tokens, const SourceManager &sourceMgr,
395 size_t &closeParenIdx, bool &hasParameters, std::vector<MethodParameter> &parameters
396) {
397 const size_t tokenCount = tokens.size();
398 closeParenIdx = tokenCount;
399 hasParameters = false;
400 parameters.clear();
401
402 // Initialize parenDepth to 1 to account for the opening '(' at tokens[i+1]
403 // Start scanning from i+2 (the first token after the opening paren)
404 size_t parenDepth = 1;
405 for (size_t j = i + 2; j < tokenCount; ++j) {
406 if (tokens[j].is(tok::l_paren)) {
407 parenDepth++;
408 } else if (tokens[j].is(tok::r_paren)) {
409 parenDepth--;
410 if (parenDepth == 0) {
411 closeParenIdx = j;
412
413 // Parse parameters between '(' and ')'
414 // Parameters follow the pattern: type name [, type name ...]
415 std::vector<Token> paramTokens;
416 for (size_t k = i + 2; k < j; ++k) {
417 if (k >= tokenCount) {
418 break;
419 }
420 if (!tokens[k].is(tok::comment)) {
421 paramTokens.push_back(tokens[k]);
422 }
423 }
424 const size_t paramTokenCount = paramTokens.size();
425
426 // Check if we have actual parameters (excluding just "void")
427 if (paramTokenCount == 1) {
428 StringRef paramToken(
429 sourceMgr.getCharacterData(paramTokens[0].getLocation()), paramTokens[0].getLength()
430 );
431 if (paramToken != "void") {
432 hasParameters = true;
433 }
434 } else if (paramTokenCount > 1) {
435 hasParameters = true;
436 }
437
438 // Parse individual parameters
439 if (hasParameters) {
440 std::string currentParamType;
441 std::string currentParamName;
442 bool inDefaultValue = false;
443
444 for (size_t k = 0; k < paramTokenCount; ++k) {
445 // Check for end of current parameter
446 if (paramTokens[k].is(tok::comma)) {
447 // Add the current parameter if valid
448 if (!currentParamType.empty() && !currentParamName.empty()) {
449 parameters.push_back(MethodParameter(currentParamType, currentParamName));
450 }
451 currentParamType.clear();
452 currentParamName.clear();
453 inDefaultValue = false;
454 continue;
455 }
456 // Skip tokens that are part of the default value
457 if (inDefaultValue) {
458 continue;
459 }
460 // Check for '=' which indicates start of default value
461 if (paramTokens[k].is(tok::equal)) {
462 inDefaultValue = true;
463 continue;
464 }
465
466 StringRef tokenText(
467 sourceMgr.getCharacterData(paramTokens[k].getLocation()), paramTokens[k].getLength()
468 );
469
470 // Identifier token could be part of the type or the parameter name.
471 // Simple heuristic: last identifier before comma, equal, or end is the name
472 if (paramTokens[k].is(tok::raw_identifier)) {
473 if (k + 1 == paramTokenCount ||
474 (k + 1 < paramTokenCount && paramTokens[k + 1].isOneOf(tok::comma, tok::equal))) {
475 currentParamName = tokenText.str();
476 continue;
477 }
478 }
479
480 // Other identifiers and other tokens (keywords, ::, *, &, etc.) are part of type.
481 llvm::raw_string_ostream paramTypeStream(currentParamType);
482 if (!currentParamType.empty() && tokenText != "*" && tokenText != "&" &&
483 tokenText != "::" && !tokenText.starts_with("::") &&
484 !StringRef(currentParamType).ends_with("::")) {
485 paramTypeStream << ' ';
486 }
487 paramTypeStream << tokenText;
488 currentParamType = paramTypeStream.str();
489 }
490
491 // Add the last parameter if valid
492 if (!currentParamType.empty() && !currentParamName.empty()) {
493 parameters.push_back(MethodParameter(currentParamType, currentParamName));
494 }
495 }
496
497 return true;
498 }
499 }
500 }
501
502 // Couldn't find closing paren
503 return false;
504}
505
513static bool
514checkConstAndFindEnd(size_t closeParenIdx, const std::vector<Token> &tokens, size_t &endIdx) {
515 bool isConst = false;
516 endIdx = closeParenIdx + 1;
517
518 while (endIdx < tokens.size()) {
519 Token curr = tokens[endIdx];
520 if (curr.isOneOf(tok::semi, tok::l_brace)) {
521 break;
522 }
523 if (curr.is(tok::raw_identifier) && curr.getRawIdentifier() == "const") {
524 isConst = true;
525 }
526 endIdx++;
527 }
528
529 return isConst;
530}
531
550SmallVector<ExtraMethod> parseExtraMethods(StringRef extraDecl) {
551 if (extraDecl.empty()) {
552 return {};
553 }
554
555 // Use ClangLexerContext for simplified setup
556 const ClangLexerContext lexerCtx(extraDecl, "extraClassDecl");
557 if (!lexerCtx.isValid()) {
558 llvm::errs() << "Error: Failed to create lexer context for parseExtraMethods\n";
559 return {};
560 }
561
562 // Store methods uniqued by name to detect and skip overloads (duplicate method names).
563 llvm::StringMap<std::optional<ExtraMethod>> methods;
564
565 // Parse tokens to find method declarations
566 const std::vector<Token> tokens = tokenize(lexerCtx);
567 const size_t tokenCount = tokens.size();
568 const SourceManager &sourceMgr = lexerCtx.getSourceManager();
569
570 // Track current access level to avoid generating C API wrappers for private functions. Code
571 // generated by `mlir-tblgen` puts the extra declarations in the public section by default.
572 AccessLevel currentAccess = AccessLevel::Public;
573
574 for (size_t i = 0; i < tokenCount; ++i) {
575 // Skip comments (they'll be extracted separately)
576 if (tokens[i].is(tok::comment)) {
577 continue;
578 }
579
580 // Check for access specifier changes (e.g., "private:", "public:", "protected:").
581 if (updateAccessLevel(i, tokens, currentAccess)) {
582 continue;
583 }
584
585 // Skip private and protected methods - no need to generate C API wrappers
586 if (currentAccess != AccessLevel::Public) {
587 continue;
588 }
589
590 // Look for pattern: [modifiers] <return_type> <identifier> '(' [params] ')' [const] ';'
591 // Look for an identifier followed by '('
592 if (i + 1 < tokenCount && tokens[i + 1].is(tok::l_paren) && tokens[i].is(tok::raw_identifier)) {
593 StringRef methodName = tokens[i].getRawIdentifier();
594
595 // Skip control flow keywords and other language constructs that use parentheses
596 if (isCppLanguageConstruct(methodName)) {
597 continue;
598 }
599
600 // Extract return type (everything before method name)
601 size_t returnTypeStart = 0;
602 std::string returnType = extractReturnType(i, tokens, sourceMgr, returnTypeStart);
603
604 // Skip static methods (return type is empty if static)
605 if (returnType.empty()) {
606 continue;
607 }
608
609 // Parse method parameters
610 size_t closeParenIdx = tokenCount;
611 bool hasParameters = false;
612 std::vector<MethodParameter> parameters;
613 if (!parseMethodParameters(i, tokens, sourceMgr, closeParenIdx, hasParameters, parameters)) {
614 // Couldn't find closing paren, skip this method
615 continue;
616 }
617
618 // Check for 'const' and find declaration end
619 size_t endIdx;
620 bool isConst = checkConstAndFindEnd(closeParenIdx, tokens, endIdx);
621
622 // Create method struct
623 if (!returnType.empty() && !methodName.empty()) {
624 if (methods.contains(methodName)) {
625 warnSkipped(methodName, "C API does not support method overloading");
626 methods[methodName] = std::nullopt;
627 } else {
628 ExtraMethod method;
629 method.returnType = returnType;
630 method.methodName = methodName;
631 method.documentation = getDocumentation(returnTypeStart, tokens, sourceMgr);
632 method.isConst = isConst;
633 method.hasParameters = hasParameters;
634 method.parameters = parameters;
635 methods[methodName] = std::make_optional(method);
636 }
637 }
638
639 // Skip to end of this declaration for the next iteration.
640 i = endIdx;
641 }
642 }
643
644 // Return valid methods, skipping overloaded names (nullopt entries).
645 return llvm::to_vector(
646 llvm::map_range(
647 llvm::make_filter_range(methods, [](const auto &p) { return p.second.has_value(); }),
648 [](const auto &p) { return p.second.value(); }
649 )
650 );
651}
652
654bool matchesMLIRClass(StringRef cppType, StringRef typeName) {
655 if (cppType == typeName) {
656 return true;
657 }
658
659 // Check for "::mlir::" or "mlir::" prefix
660 StringRef prefix = cppType;
661 prefix.consume_front("::");
662 if (prefix.consume_front("mlir::")) {
663 return prefix == typeName;
664 }
665
666 return false;
667}
668
670std::optional<std::string> tryCppTypeToCapiType(StringRef cppType) {
671 cppType = cppType.trim();
672
673 // Primitive types are unchanged
674 if (isPrimitiveType(cppType)) {
675 return std::make_optional(cppType.str());
676 }
677
678 // APInt type is converted via llzk::fromAPInt()
679 if (isAPIntType(cppType)) {
680 return std::make_optional("int64_t");
681 }
682
683 // Pointer type conversions happen via the `unwrap()` function generated
684 // by `DEFINE_C_API_PTR_METHODS()` in `mlir/CAPI/IR.h`
685 if (cppType.ends_with(" *") || cppType.ends_with("*")) {
686 size_t starPos = cppType.rfind('*');
687 if (starPos != StringRef::npos) {
688 StringRef baseType = cppType.substr(0, starPos).trim();
689 if (matchesMLIRClass(baseType, "AsmState")) {
690 return std::make_optional("MlirAsmState");
691 }
692 if (matchesMLIRClass(baseType, "BytecodeWriterConfig")) {
693 return std::make_optional("MlirBytecodeWriterConfig");
694 }
695 if (matchesMLIRClass(baseType, "MLIRContext")) {
696 return std::make_optional("MlirContext");
697 }
698 if (matchesMLIRClass(baseType, "Dialect")) {
699 return std::make_optional("MlirDialect");
700 }
701 if (matchesMLIRClass(baseType, "DialectRegistry")) {
702 return std::make_optional("MlirDialectRegistry");
703 }
704 if (matchesMLIRClass(baseType, "Operation")) {
705 return std::make_optional("MlirOperation");
706 }
707 if (matchesMLIRClass(baseType, "Block")) {
708 return std::make_optional("MlirBlock");
709 }
710 if (matchesMLIRClass(baseType, "OpOperand")) {
711 return std::make_optional("MlirOpOperand");
712 }
713 if (matchesMLIRClass(baseType, "OpPrintingFlags")) {
714 return std::make_optional("MlirOpPrintingFlags");
715 }
716 if (matchesMLIRClass(baseType, "Region")) {
717 return std::make_optional("MlirRegion");
718 }
719 if (matchesMLIRClass(baseType, "SymbolTable")) {
720 return std::make_optional("MlirSymbolTable");
721 }
722 } else {
723 llvm::errs() << "Error: Failed to parse pointer type: " << cppType << '\n';
724 }
725 }
726
727 // These have `wrap()`/`unwrap()` generated by `DEFINE_C_API_METHODS()` in...
728 // ... `mlir/CAPI/IR.h`
729 if (matchesMLIRClass(cppType, "Attribute")) {
730 return std::make_optional("MlirAttribute");
731 }
732 if (matchesMLIRClass(cppType, "StringAttr")) {
733 return std::make_optional("MlirIdentifier");
734 }
735 if (matchesMLIRClass(cppType, "Location")) {
736 return std::make_optional("MlirLocation");
737 }
738 if (matchesMLIRClass(cppType, "ModuleOp")) {
739 return std::make_optional("MlirModule");
740 }
741 if (matchesMLIRClass(cppType, "Type")) {
742 return std::make_optional("MlirType");
743 }
744 if (matchesMLIRClass(cppType, "Value")) {
745 return std::make_optional("MlirValue");
746 }
747 // ... `mlir/CAPI/AffineExpr.h`
748 if (matchesMLIRClass(cppType, "AffineExpr")) {
749 return std::make_optional("MlirAffineExpr");
750 }
751 // ... `mlir/CAPI/AffineMap.h`
752 if (matchesMLIRClass(cppType, "AffineMap")) {
753 return std::make_optional("MlirAffineMap");
754 }
755 // ... `mlir/CAPI/IntegerSet.h`
756 if (matchesMLIRClass(cppType, "IntegerSet")) {
757 return std::make_optional("MlirIntegerSet");
758 }
759 // ... `mlir/CAPI/Support.h`
760 if (matchesMLIRClass(cppType, "TypeID")) {
761 return std::make_optional("MlirTypeID");
762 }
763
764 // These have `wrap()`/`unwrap()` manually defined in `mlir/CAPI/Support.h`
765 if (matchesMLIRClass(cppType, "StringRef")) {
766 return std::make_optional("MlirStringRef");
767 }
768 if (matchesMLIRClass(cppType, "LogicalResult")) {
769 return std::make_optional("MlirLogicalResult");
770 }
771
772 // Heuristically map custom dialect classes to their C API equivalents
773 if (cppType.ends_with("Type")) {
774 return std::make_optional("MlirType");
775 }
776 if (cppType.ends_with("Attr")) {
777 return std::make_optional("MlirAttribute");
778 }
779 if (cppType.ends_with("Op")) {
780 return std::make_optional("MlirOperation");
781 }
782
783 // Otherwise, not sure how to convert it
784 return std::nullopt;
785}
786
787// Map C++ type to corresponding C API type
788std::string mapCppTypeToCapiType(StringRef cppType) {
789 assert(!isArrayRefType(cppType) && "must check `isArrayRefType()` outside");
790
791 std::optional<std::string> capiTypeOpt = tryCppTypeToCapiType(cppType);
792 if (capiTypeOpt.has_value()) {
793 return capiTypeOpt.value();
794 }
795
796 // Otherwise assume it's a type where the C name is a direct translation from the C++ name.
797 return toPascalCase(cppType);
798}
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
std::string mapCppTypeToCapiType(StringRef cppType)
AccessLevel
Access level tracking for C++ class declarations.
SmallVector< ExtraMethod > parseExtraMethods(StringRef extraDecl)
Parse method declarations from extraClassDeclaration using Clang's Lexer.
bool matchesMLIRClass(StringRef cppType, StringRef typeName)
Check if a C++ type matches an MLIR type pattern.
llvm::cl::OptionCategory OpGenCat
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
bool isCppModifierKeyword(mlir::StringRef tokenText)
Check if a token text represents a C++ modifier/specifier keyword.
bool isAPIntType(mlir::StringRef cppType)
Check if a C++ type is APInt.
llvm::cl::opt< std::string > FunctionPrefix
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
void warnSkipped(const S &methodName, const std::string &message)
Print warning about skipping a function.
bool isCppLanguageConstruct(mlir::StringRef methodName)
Check if a method name represents a C++ control flow keyword or language construct.
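Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.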
Definition LICENSE.txt:28
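Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.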
Definition LICENSE.txt:32
RAII wrapper for Clang lexer infrastructure.
clang::SourceManager & getSourceManager() const
Get the source manager instance.
bool isValid() const
Check if the lexer was successfully created.
ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName="input")
Construct a lexer context for the given source code.
clang::Lexer & getLexer() const
Get the lexer instance.
IntrusiveRefCntPtr< DiagnosticOptions > diagOpts
Diagnostic options for configuring diagnostics.
IntrusiveRefCntPtr< FileManager > fileMgr
File manager for handling virtual files.
IntrusiveRefCntPtr< DiagnosticIDs > diagIDs
Diagnostic IDs for error reporting.
LangOptions langOpts
C++ language options for lexer configuration.
std::unique_ptr< SourceManager > sourceMgr
Source manager for tracking file locations.
std::unique_ptr< Lexer > lexer
The actual lexer instance.
std::unique_ptr< DiagnosticsEngine > diags
Diagnostics engine for handling errors and warnings.
Structure to represent a parsed method signature from an extraClassDeclaration
bool isConst
Whether the method is const-qualified.
bool hasParameters
Whether the method has parameters (unsupported for now)
std::vector< MethodParameter > parameters
The parameters of the method.
std::string returnType
The C++ return type of the method.
std::string methodName
The name of the method.
std::string documentation
Documentation comment (if any)
Structure to represent a parameter in a parsed method signature from an extraClassDeclaration