LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
CommonCAPIGen.h
Go to the documentation of this file.
1//===- CommonCAPIGen.h - Common utilities for C API generation ------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Common utilities shared between all CAPI generators (ops, attrs, types)
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15
16#include <mlir/TableGen/Dialect.h>
17
18#include <llvm/ADT/StringExtras.h>
19#include <llvm/ADT/StringRef.h>
20#include <llvm/ADT/StringSwitch.h>
21#include <llvm/Support/CommandLine.h>
22#include <llvm/Support/FormatVariadic.h>
23
24#include <memory>
25#include <string>
26
27constexpr bool WARN_SKIPPED_METHODS = false;
28
30template <typename S> inline void warnSkipped(const S &methodName, const std::string &message) {
32 llvm::errs() << "Warning: Skipping method '" << methodName << "' - " << message << '\n';
33 }
34}
35
37template <typename S>
38inline void warnSkippedNoConversion(const S &methodName, const std::string &cppType) {
40 warnSkipped(methodName, "no conversion to C API type for '" + cppType + '\'');
41 }
42}
43
44// Forward declarations for Clang classes
45namespace clang {
46class Lexer;
47class SourceManager;
48} // namespace clang
49
50// Shared command-line options used by all CAPI generators
51extern llvm::cl::OptionCategory OpGenCat;
52extern llvm::cl::opt<std::string> DialectName;
53extern llvm::cl::opt<std::string> FunctionPrefix;
54
55// Shared flags for controlling code generation
56extern llvm::cl::opt<bool> GenIsA;
57extern llvm::cl::opt<bool> GenOpBuild;
58extern llvm::cl::opt<bool> GenOpOperandGetters;
59extern llvm::cl::opt<bool> GenOpOperandSetters;
60extern llvm::cl::opt<bool> GenOpAttributeGetters;
61extern llvm::cl::opt<bool> GenOpAttributeSetters;
62extern llvm::cl::opt<bool> GenOpRegionGetters;
63extern llvm::cl::opt<bool> GenOpResultGetters;
64extern llvm::cl::opt<bool> GenTypeOrAttrGet;
65extern llvm::cl::opt<bool> GenTypeOrAttrParamGetters;
66extern llvm::cl::opt<bool> GenExtraClassMethods;
67
75inline std::string toPascalCase(mlir::StringRef str) {
76 if (str.empty()) {
77 return "";
78 }
79
80 std::string result;
81 result.reserve(str.size());
82 llvm::raw_string_ostream resultStream(result);
83 bool capitalizeNext = true;
84
85 for (char c : str) {
86 if (c == '_' || c == ':') {
87 capitalizeNext = true;
88 } else {
89 resultStream << (capitalizeNext ? llvm::toUpper(c) : c);
90 capitalizeNext = false;
91 }
92 }
93
94 return result;
95}
96
100inline bool isIntegerType(mlir::StringRef type) {
101 // Consume optional root namespace token
102 type.consume_front("::");
103 // Handle special names first
104 if (type == "signed" || type == "unsigned" || type == "size_t" || type == "char32_t" ||
105 type == "char16_t" || type == "char8_t" || type == "wchar_t") {
106 return true;
107 }
108 // Handle standard integer types with optional signed/unsigned prefix
109 type.consume_front("signed ") || type.consume_front("unsigned ");
110 if (type == "char" || type == "int" || type == "short" || type == "short int" || type == "long" ||
111 type == "long int" || type == "long long" || type == "long long int") {
112 return true;
113 }
114 // Handle fixed-width integer types (https://cppreference.com/w/cpp/types/integer.html)
115 type.consume_front("std::"); // optional
116 if (type.consume_back("_t") && (type.consume_front("int") || type.consume_front("uint"))) {
117 // intmax_t, intptr_t, uintmax_t, uintptr_t
118 if (type == "max" || type == "ptr") {
119 return true;
120 }
121 // Optional "_fast" or "_least" and finally bit width to cover the rest
122 type.consume_back("_fast") || type.consume_back("_least");
123 if (type == "8" || type == "16" || type == "32" || type == "64") {
124 return true;
125 }
126 }
127 return false;
128}
129
136inline bool isPrimitiveType(mlir::StringRef cppType) {
137 cppType.consume_front("::");
138 return cppType == "void" || cppType == "bool" || cppType == "float" || cppType == "double" ||
139 cppType == "long double" || isIntegerType(cppType);
140}
141
145inline bool isCppModifierKeyword(mlir::StringRef tokenText) {
146 return llvm::StringSwitch<bool>(tokenText)
147 .Case("inline", true)
148 .Case("static", true)
149 .Case("virtual", true)
150 .Case("explicit", true)
151 .Case("constexpr", true)
152 .Case("consteval", true)
153 .Case("extern", true)
154 .Case("mutable", true)
155 .Case("friend", true)
156 .Default(false);
157}
158
162inline bool isCppLanguageConstruct(mlir::StringRef methodName) {
163 return llvm::StringSwitch<bool>(methodName)
164 .Case("if", true)
165 .Case("for", true)
166 .Case("while", true)
167 .Case("switch", true)
168 .Case("return", true)
169 .Case("sizeof", true)
170 .Case("decltype", true)
171 .Case("alignof", true)
172 .Case("typeid", true)
173 .Case("static_assert", true)
174 .Case("noexcept", true)
175 .Default(false);
176}
177
181inline bool isAPIntType(mlir::StringRef cppType) {
182 cppType.consume_front("::");
183 cppType.consume_front("llvm::") || cppType.consume_front("mlir::");
184 return cppType == "APInt";
185}
186
190inline bool isArrayRefType(mlir::StringRef cppType) {
191 cppType.consume_front("::");
192 cppType.consume_front("llvm::") || cppType.consume_front("mlir::");
193 return cppType.starts_with("ArrayRef<");
194}
195
197inline mlir::StringRef extractArrayRefElementType(mlir::StringRef cppType) {
198 assert(isArrayRefType(cppType) && "must check `isArrayRefType()` outside");
199
200 // Remove "ArrayRef<" prefix and ">" suffix
201 cppType.consume_front("::");
202 cppType.consume_front("llvm::") || cppType.consume_front("mlir::");
203 cppType.consume_front("ArrayRef<") && cppType.consume_back(">");
204 return cppType;
205}
206
216public:
220 explicit ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName = "input");
221
224 clang::Lexer &getLexer() const;
225
228 clang::SourceManager &getSourceManager() const;
229
232 bool isValid() const { return lexer != nullptr; }
233
234private:
235 struct Impl;
236 std::unique_ptr<Impl> impl;
237 clang::Lexer *lexer = nullptr;
238};
239
244 std::string type;
246 std::string name;
247
251 MethodParameter(const std::string &paramType, const std::string &paramName)
252 : type(mlir::StringRef(paramType).trim().str()),
253 name(mlir::StringRef(paramName).trim().str()) {}
254};
255
262 std::string returnType;
264 std::string methodName;
266 std::string documentation;
268 bool isConst = false;
270 bool hasParameters = false;
272 std::vector<MethodParameter> parameters;
273};
274
301llvm::SmallVector<ExtraMethod> parseExtraMethods(mlir::StringRef extraDecl);
302
307bool matchesMLIRClass(mlir::StringRef cppType, mlir::StringRef typeName);
308
312std::optional<std::string> tryCppTypeToCapiType(mlir::StringRef cppType);
313
320std::string mapCppTypeToCapiType(mlir::StringRef cppType);
321
323struct Generator {
324 Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
325 : kind(recordKind), os(outputStream), dialectNameCapitalized(toPascalCase(DialectName)) {}
326 virtual ~Generator() = default;
327
331 virtual void
332 setDialectAndClassName(const mlir::tblgen::Dialect *d, mlir::StringRef cppClassName) {
333 this->dialect = d;
334 this->className = cppClassName;
335 }
336
339 virtual void genExtraMethods(mlir::StringRef extraDecl) const {
340 if (extraDecl.empty()) {
341 return;
342 }
343 for (const ExtraMethod &method : parseExtraMethods(extraDecl)) {
344 genExtraMethod(method);
345 }
346 }
347
350 virtual void genExtraMethod(const ExtraMethod &method) const = 0;
351
352protected:
353 std::string kind;
354 llvm::raw_ostream &os;
356 const mlir::tblgen::Dialect *dialect;
357 mlir::StringRef className;
358};
359
361struct HeaderGenerator : public Generator {
363 virtual ~HeaderGenerator() = default;
364
365 virtual void genPrologue() const {
366 os << R"(
367#include "llzk-c/Builder.h"
368#include <mlir-c/IR.h>
370#ifdef __cplusplus
371extern "C" {
372#endif
373)";
374 }
375
376 virtual void genEpilogue() const {
377 os << R"(
378#ifdef __cplusplus
379}
380#endif
381)";
382 }
383
384 virtual void genIsADecl() const {
385 static constexpr char fmt[] = R"(
386/* Returns true if the {1} is a {4}::{3}. */
387MLIR_CAPI_EXPORTED bool {0}{1}IsA{2}{3}(Mlir{1});
388)";
389 assert(dialect && "Dialect must be set");
390 os << llvm::formatv(
391 fmt,
392 FunctionPrefix, // {0}
393 kind, // {1}
395 className, // {3}
396 dialect->getCppNamespace() // {4}
397 );
398 }
399
401 virtual void genExtraMethod(const ExtraMethod &method) const override {
402 // Convert return type to C API type, skip if it can't be converted
403 std::optional<std::string> capiReturnTypeOpt = tryCppTypeToCapiType(method.returnType);
404 if (!capiReturnTypeOpt.has_value()) {
406 return;
407 }
408 std::string capiReturnType = capiReturnTypeOpt.value();
409
410 // Build parameter list
411 std::string paramList;
412 llvm::raw_string_ostream paramListStream(paramList);
413 paramListStream << llvm::formatv("Mlir{0} inp", kind);
414 for (const auto &param : method.parameters) {
415 // Convert C++ type to C API type for parameter, skip if it can't be converted
416 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(param.type);
417 if (!capiParamTypeOpt.has_value()) {
418 warnSkippedNoConversion(method.methodName, param.type);
419 return;
420 }
421 std::string capiParamType = capiParamTypeOpt.value();
422 paramListStream << ", " << capiParamType << ' ' << param.name;
423 }
424
425 // Generate declaration
426 std::string docComment =
427 method.documentation.empty() ? method.methodName : method.documentation;
428
429 os << llvm::formatv("\n/* {0} */\n", docComment);
430 os << llvm::formatv(
431 "MLIR_CAPI_EXPORTED {0} {1}{2}{3}{4}({5});\n",
432 capiReturnType, // {0}
433 FunctionPrefix, // {1}
435 className, // {3}
436 toPascalCase(method.methodName), // {4}
437 paramList // {5}
438 );
439 }
440};
441
445 virtual ~ImplementationGenerator() = default;
446
447 virtual void genIsAImpl() const {
448 static constexpr char fmt[] = R"(
449bool {0}{1}IsA{2}{3}(Mlir{1} inp) {{
450 return llvm::isa<{3}>(unwrap(inp));
451}
452)";
453 assert(!className.empty() && "className must be set");
455 }
456
458 virtual void genExtraMethod(const ExtraMethod &method) const override {
459 // Convert return type to C API type, skip if it can't be converted
460 std::optional<std::string> capiReturnTypeOpt = tryCppTypeToCapiType(method.returnType);
461 if (!capiReturnTypeOpt.has_value()) {
463 return;
464 }
465 std::string capiReturnType = capiReturnTypeOpt.value();
466
467 // Build the return statement prefix and suffix
468 std::string returnPrefix;
469 std::string returnSuffix;
470 mlir::StringRef cppReturnType = method.returnType;
471
472 if (cppReturnType == "void") {
473 // "void" type doesn't even need "return"
474 returnPrefix = "";
475 returnSuffix = "";
476 } else {
477 // Check if return needs wrapping
478 if (isPrimitiveType(cppReturnType)) {
479 // Primitive types don't need wrapping
480 returnPrefix = "return ";
481 returnSuffix = "";
482 } else if (capiReturnType.starts_with("Mlir") || isAPIntType(cppReturnType)) {
483 // MLIR C API types and APInt type need wrapping
484 returnPrefix = "return wrap(";
485 returnSuffix = ")";
486 } else {
487 return;
488 }
489 }
490
491 // Build parameter list for C API function signature
492 std::string paramList;
493 llvm::raw_string_ostream paramListStream(paramList);
494 paramListStream << llvm::formatv("Mlir{0} inp", kind);
495 for (const auto &param : method.parameters) {
496 // Convert C++ type to C API type for parameter, skip if it can't be converted
497 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(param.type);
498 if (!capiParamTypeOpt.has_value()) {
499 warnSkippedNoConversion(method.methodName, param.type);
500 return;
501 }
502 std::string capiParamType = capiParamTypeOpt.value();
503 paramListStream << ", " << capiParamType << ' ' << param.name;
504 }
505
506 // Build argument list for C++ method call
507 std::string argList;
508 llvm::raw_string_ostream argListStream(argList);
509 for (size_t i = 0; i < method.parameters.size(); ++i) {
510 if (i > 0) {
511 argListStream << ", ";
512 }
513 const auto &param = method.parameters[i];
514
515 // Check if parameter needs unwrapping
516 mlir::StringRef cppParamType = param.type;
517 if (isPrimitiveType(cppParamType)) {
518 // Primitive types don't need unwrapping
519 argListStream << param.name;
520 } else if (isAPIntType(cppParamType)) {
521 // APInt needs unwrapping
522 argListStream << "unwrap(" << param.name << ')';
523 } else {
524 // Convert C++ type to C API type for parameter, skip if it can't be converted
525 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(cppParamType);
526 if (capiParamTypeOpt.has_value() && capiParamTypeOpt->starts_with("Mlir")) {
527 // MLIR C API types need unwrapping
528 argListStream << "unwrap(" << param.name << ')';
529 } else {
530 warnSkippedNoConversion(method.methodName, cppParamType.str());
531 return;
532 }
533 }
534 }
535
536 // Generate implementation
537 os << '\n';
538 os << llvm::formatv(
539 "{0} {1}{2}{3}{4}({5}) {{\n",
540 capiReturnType, // {0}
541 FunctionPrefix, // {1}
543 className, // {3}
544 toPascalCase(method.methodName), // {4}
545 paramList // {5}
546 );
547 os << llvm::formatv(
548 " {0}llvm::cast<{1}>(unwrap(inp)).{2}({3}){4};\n",
549 returnPrefix, // {0}
550 className, // {1}
551 method.methodName, // {2}
552 argList, // {3}
553 returnSuffix // {4}
554 );
555 os << "}\n";
556 }
557};
558
560struct TestGenerator : public Generator {
562 virtual ~TestGenerator() = default;
563
565 virtual void genTestClassPrologue() const {
566 static constexpr char fmt[] = "class {0}{1}LinkTests : public CAPITest {{};\n";
567 os << llvm::formatv(fmt, dialectNameCapitalized, kind);
568 }
569
571 virtual void genIsATest() const {
572 static constexpr char fmt[] = R"(
573// This test ensures {0}{1}IsA{2}{3} links properly.
574TEST_F({2}{1}LinkTests, IsA_{2}{3}) {{
575 auto test{1} = createIndex{1}();
576
577 // This will always return false since `createIndex*` returns an MLIR builtin
578 EXPECT_FALSE({0}{1}IsA{2}{3}(test{1}));
579
580 {4}(test{1});
581}
582)";
583 assert(!className.empty() && "className must be set");
584 os << llvm::formatv(
585 fmt,
586 FunctionPrefix, // {0}
587 kind, // {1}
589 className, // {3}
590 genCleanup() // {4}
591 );
592 }
593
595 virtual void genExtraMethod(const ExtraMethod &method) const override {
596 // Convert return type to C API type, skip if it can't be converted
597 std::optional<std::string> capiReturnTypeOpt = tryCppTypeToCapiType(method.returnType);
598 if (!capiReturnTypeOpt.has_value()) {
600 return;
601 }
602
603 // Build parameter list for dummy values
604 std::string dummyParams;
605 llvm::raw_string_ostream dummyParamsStream(dummyParams);
606 std::string paramList;
607 llvm::raw_string_ostream paramListStream(paramList);
608
609 for (const auto &param : method.parameters) {
610 // Convert C++ type to C API type for parameter, skip if it can't be converted
611 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(param.type);
612 if (!capiParamTypeOpt.has_value()) {
613 warnSkippedNoConversion(method.methodName, param.type);
614 return;
615 }
616 std::string capiParamType = capiParamTypeOpt.value();
617 std::string name = param.name;
618
619 // Generate dummy value creation for each parameter
620 if (capiParamType == "bool") {
621 dummyParamsStream << " bool " << name << " = false;\n";
622 } else if (capiParamType == "MlirValue") {
623 dummyParamsStream << " auto " << name << " = mlirOperationGetResult(testOp, 0);\n";
624 } else if (capiParamType == "MlirType") {
625 dummyParamsStream << " auto " << name << " = createIndexType();\n";
626 } else if (capiParamType == "MlirAttribute") {
627 dummyParamsStream << " auto " << name << " = createIndexAttribute();\n";
628 } else if (capiParamType == "MlirStringRef") {
629 dummyParamsStream << " auto " << name << " = mlirStringRefCreateFromCString(\"\");\n";
630 } else if (isIntegerType(capiParamType)) {
631 dummyParamsStream << " " << capiParamType << ' ' << name << " = 0;\n";
632 } else {
633 // For unknown types, create a default-initialized variable
634 dummyParamsStream << " " << capiParamType << ' ' << name << " = {};\n";
635 }
636
637 paramListStream << ", " << name;
638 }
639
640 static constexpr char fmt[] = R"(
641// This test ensures {0}{2}{3}{4} links properly.
642TEST_F({2}{1}LinkTests, {0}_{3}_{4}) {{
643 auto test{1} = createIndex{1}();
644
645 if ({0}{1}IsA{2}{3}(test{1})) {{
646{5}
647 (void){0}{2}{3}{4}(test{1}{6});
648 }
649
650 {7}(test{1});
651}
652)";
653 assert(!className.empty() && "className must be set");
654 os << llvm::formatv(
655 fmt,
656 FunctionPrefix, // {0}
657 kind, // {1}
659 className, // {3}
660 toPascalCase(method.methodName), // {4}
661 dummyParams, // {5}
662 paramList, // {6}
663 genCleanup() // {7}
664 );
665 }
666
676 virtual std::string genCleanup() const {
677 // The default case is to just comment out the rest of the cleanup line
678 return "//";
679 }
680};
mlir::StringRef extractArrayRefElementType(mlir::StringRef cppType)
Extract element type from ArrayRef<...>
llvm::cl::OptionCategory OpGenCat
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
void warnSkippedNoConversion(const S &methodName, const std::string &cppType)
Print warning about skipping a function due to no conversion of C++ type to C API type.
llvm::cl::opt< bool > GenIsA
std::string mapCppTypeToCapiType(mlir::StringRef cppType)
Map C++ type to corresponding C API type.
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
bool isCppModifierKeyword(mlir::StringRef tokenText)
Check if a token text represents a C++ modifier/specifier keyword.
bool isAPIntType(mlir::StringRef cppType)
Check if a C++ type is APInt.
llvm::cl::opt< std::string > FunctionPrefix
std::optional< std::string > tryCppTypeToCapiType(mlir::StringRef cppType)
Convert C++ type to MLIR C API type.
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
llvm::cl::opt< bool > GenOpRegionGetters
constexpr bool WARN_SKIPPED_METHODS
bool isIntegerType(mlir::StringRef type)
Check if a C++ type is a known integer type.
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
bool matchesMLIRClass(mlir::StringRef cppType, mlir::StringRef typeName)
Check if a C++ type matches an MLIR type pattern.
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
llvm::SmallVector< ExtraMethod > parseExtraMethods(mlir::StringRef extraDecl)
Parse method declarations from an extraClassDeclaration using Clang's Lexer.
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
void warnSkipped(const S &methodName, const std::string &message)
Print warning about skipping a function.
bool isCppLanguageConstruct(mlir::StringRef methodName)
Check if a method name represents a C++ control flow keyword or language construct.
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation source
Definition LICENSE.txt:28
clang::SourceManager & getSourceManager() const
Get the source manager instance.
bool isValid() const
Check if the lexer was successfully created.
ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName="input")
Construct a lexer context for the given source code.
clang::Lexer & getLexer() const
Get the lexer instance.
Structure to represent a parsed method signature from an extraClassDeclaration
bool isConst
Whether the method is const-qualified.
bool hasParameters
Whether the method has parameters (unsupported for now)
std::vector< MethodParameter > parameters
The parameters of the method.
std::string returnType
The C++ return type of the method.
std::string methodName
The name of the method.
std::string documentation
Documentation comment (if any)
virtual void setDialectAndClassName(const mlir::tblgen::Dialect *d, mlir::StringRef cppClassName)
Set the dialect and class name for code generation.
virtual ~Generator()=default
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
mlir::StringRef className
virtual void genExtraMethods(mlir::StringRef extraDecl) const
Generate code for extra methods from an extraClassDeclaration
const mlir::tblgen::Dialect * dialect
virtual void genExtraMethod(const ExtraMethod &method) const =0
Generate code for an extra method.
std::string dialectNameCapitalized
llvm::raw_ostream & os
std::string kind
Generator for common C header file elements.
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
virtual void genPrologue() const
virtual ~HeaderGenerator()=default
virtual void genEpilogue() const
virtual void genIsADecl() const
virtual void genExtraMethod(const ExtraMethod &method) const override
Generate declaration for an extra method from an extraClassDeclaration
Generator for common C implementation file elements.
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
virtual ~ImplementationGenerator()=default
virtual void genIsAImpl() const
virtual void genExtraMethod(const ExtraMethod &method) const override
Generate implementation for an extra method from an extraClassDeclaration
std::string name
The name of the parameter.
std::string type
The C++ type of the parameter.
MethodParameter(const std::string &paramType, const std::string &paramName)
Construct a new Method Parameter object.
Generator for common test implementation file elements.
virtual void genTestClassPrologue() const
Generate the test class prologue.
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
virtual ~TestGenerator()=default
virtual void genIsATest() const
Generate IsA test for a class.
virtual std::string genCleanup() const
Generate cleanup code for test methods.
virtual void genExtraMethod(const ExtraMethod &method) const override
Generate test for an extra method from extraClassDeclaration.