14#include <mlir/TableGen/GenInfo.h>
15#include <mlir/TableGen/Operator.h>
17#include <llvm/Support/CommandLine.h>
18#include <llvm/Support/FormatVariadic.h>
19#include <llvm/TableGen/Error.h>
20#include <llvm/TableGen/Record.h>
21#include <llvm/TableGen/TableGenBackend.h>
30using namespace mlir::tblgen;
38 name.empty() ? llvm::formatv(
"Result{0}", resultIndex).str() :
toPascalCase(name);
42 name.empty() ? llvm::formatv(
"Region{0}", regionIndex).str() :
toPascalCase(name);
54 using HeaderGenerator::HeaderGenerator;
58 static constexpr char fmt[] = R
"(
59/* Build a {4}::{2} Operation. */
60MLIR_CAPI_EXPORTED MlirOperation {0}{1}{2}Build(MlirOpBuilder builder, MlirLocation location{3});
62 assert(!className.empty() && "className must be set");
74 static constexpr char fmt[] = R
"(
75/* Get {3} operand from {4}::{2} Operation. */
76MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}(MlirOperation op);
78 assert(!className.empty() && "className must be set");
91 static constexpr char fmt[] = R
"(
92/* Set {3} operand of {4}::{2} Operation. */
93MLIR_CAPI_EXPORTED void {0}{1}{2}Set{3}(MlirOperation op, MlirValue value);
95 assert(!className.empty() && "className must be set");
108 static constexpr char fmt[] = R
"(
109/* Get number of {3} operands in {4}::{2} Operation. */
110MLIR_CAPI_EXPORTED intptr_t {0}{1}{2}Get{3}Count(MlirOperation op);
112/* Get {3} operand at index from {4}::{2} Operation. */
113MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index);
115 assert(!className.empty() && "className must be set");
129 static constexpr char fmt[] = R
"(
130/* Set {3} operands of {4}::{2} Operation. */
131MLIR_CAPI_EXPORTED void {0}{1}{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values);
133 assert(!className.empty() && "className must be set");
147 static constexpr char fmt[] = R
"(
148/* Get {3} attribute from {4}::{2} Operation. */
149MLIR_CAPI_EXPORTED MlirAttribute {0}{1}{2}Get{3}(MlirOperation op);
151 assert(!className.empty() && "className must be set");
164 static constexpr char fmt[] = R
"(
165/* Set {3} attribute of {4}::{2} Operation. */
166MLIR_CAPI_EXPORTED void {0}{1}{2}Set{3}(MlirOperation op, MlirAttribute attr);
168 assert(!className.empty() && "className must be set");
181 static constexpr char fmt[] = R
"(
182/* Get {3} result from {4}::{2} Operation. */
183MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}(MlirOperation op);
185 assert(!className.empty() && "className must be set");
198 static constexpr char fmt[] = R
"(
199/* Get number of {3} results in {4}::{2} Operation. */
200MLIR_CAPI_EXPORTED intptr_t {0}{1}{2}Get{3}Count(MlirOperation op);
202/* Get {3} result at index from {4}::{2} Operation. */
203MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index);
205 assert(!className.empty() && "className must be set");
218 static constexpr char fmt[] = R
"(
219/* Get {3} region from {4}::{2} Operation. */
220MLIR_CAPI_EXPORTED MlirRegion {0}{1}{2}Get{3}(MlirOperation op);
222 assert(!className.empty() && "className must be set");
235 static constexpr char fmt[] = R
"(
236/* Get number of {3} regions in {4}::{2} Operation. */
237MLIR_CAPI_EXPORTED intptr_t {0}{1}{2}Get{3}Count(MlirOperation op);
239/* Get {3} region at index from {4}::{2} Operation. */
240MLIR_CAPI_EXPORTED MlirRegion {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index);
242 assert(!className.empty() && "className must be set");
260static std::string generateCAPIBuildParams(
const Operator &op) {
263 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
265 if (result.isVariadic()) {
266 os << llvm::formatv(
", intptr_t {0}Size, MlirType const *{0}Types", resultName);
268 os << llvm::formatv(
", MlirType {0}Type", resultName);
271 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
272 if (operand.isVariadic()) {
273 os << llvm::formatv(
", intptr_t {0}Size, MlirValue const *{0}", operand.name);
275 os << llvm::formatv(
", MlirValue {0}", operand.name);
278 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
280 os << llvm::formatv(
", {0} {1}", attrType.value_or(
"MlirAttribute"), attr.name);
282 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
283 if (region.isVariadic()) {
284 os << llvm::formatv(
", unsigned {0}Count", region.name);
287 } paramStringGenerator;
288 return paramStringGenerator.gen(op);
292static bool emitOpCAPIHeader(
const llvm::RecordKeeper &records, raw_ostream &os) {
293 emitSourceFileHeader(
"Op C API Declarations", os, records);
296 generator.genPrologue();
298 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
299 const Operator op(def);
300 const Dialect &dialect = op.getDialect();
307 generator.setDialectAndClassName(&dialect, op.getCppClassName());
310 if (
GenOpBuild && !op.skipDefaultBuilders()) {
311 generator.genOpBuildDecl(generateCAPIBuildParams(op));
316 generator.genIsADecl();
320 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
321 const auto &operand = op.getOperand(i);
322 generator.setOperandName(operand.name);
323 if (operand.isVariadic()) {
325 generator.genVariadicOperandGetterDecl();
328 generator.genVariadicOperandSetterDecl();
332 generator.genOperandGetterDecl();
335 generator.genOperandSetterDecl();
341 for (
const auto &namedAttr : op.getAttributes()) {
342 generator.setAttributeName(namedAttr.name);
344 generator.genAttributeGetterDecl();
347 generator.genAttributeSetterDecl();
353 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
354 const auto &result = op.getResult(i);
355 generator.setResultName(result.name, i);
356 if (result.isVariadic()) {
357 generator.genVariadicResultGetterDecl();
359 generator.genResultGetterDecl();
366 for (
unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
367 const auto ®ion = op.getRegion(i);
368 generator.setRegionName(region.name, i);
369 if (region.isVariadic()) {
370 generator.genVariadicRegionGetterDecl();
372 generator.genRegionGetterDecl();
379 generator.genExtraMethods(op.getExtraClassDeclaration());
383 generator.genEpilogue();
389 using ImplementationGenerator::ImplementationGenerator;
397 const std::string &operationName,
const std::string ¶ms,
const std::string &assignments
399 static constexpr char fmt[] = R
"(
400MlirOperation {0}{1}{2}Build(MlirOpBuilder builder, MlirLocation location{3}) {{
401 MlirOperationState state = mlirOperationStateGet(mlirStringRefCreateFromCString("{4}"), location);
403 return mlirOpBuilderInsert(builder, mlirOperationCreate(&state));
406 assert(!className.empty() && "className must be set");
419 static constexpr char fmt[] = R
"(
420MlirValue {0}{1}{2}Get{3}(MlirOperation op) {{
421 return mlirOperationGetOperand(op, {4});
424 assert(!className.empty() && "className must be set");
437 static constexpr char fmt[] = R
"(
438void {0}{1}{2}Set{3}(MlirOperation op, MlirValue value) {{
439 mlirOperationSetOperand(op, {4}, value);
442 assert(!className.empty() && "className must be set");
455 static constexpr char fmt[] = R
"(
456intptr_t {0}{1}{2}Get{3}Count(MlirOperation op) {{
457 intptr_t count = mlirOperationGetNumOperands(op);
458 assert(count >= {4} && "operand count less than start index");
462MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index) {{
463 return mlirOperationGetOperand(op, {4} + index);
466 assert(!className.empty() && "className must be set");
479 static constexpr char fmt[] = R
"(
480void {0}{1}{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values) {{
481 intptr_t numOperands = mlirOperationGetNumOperands(op);
482 intptr_t startIdx = {4};
485 if (startIdx < 0 || startIdx > numOperands) {{
488 if (count < 0 || count > (std::numeric_limits<intptr_t>::max() - startIdx)) {{
492 intptr_t oldCount = numOperands - startIdx;
493 intptr_t newNumOperands = startIdx + count;
495 std::vector<MlirValue> newOperands(newNumOperands);
497 // Copy operands before this variadic group
498 for (intptr_t i = 0; i < startIdx; ++i) {{
499 newOperands[i] = mlirOperationGetOperand(op, i);
502 // Copy new variadic operands
503 for (intptr_t i = 0; i < count; ++i) {{
504 newOperands[startIdx + i] = values[i];
507 // Copy operands after this variadic group
508 for (intptr_t i = startIdx + oldCount; i < numOperands; ++i) {{
509 newOperands[i - oldCount + count] = mlirOperationGetOperand(op, i);
512 mlirOperationSetOperands(op, newNumOperands, newOperands.data());
515 assert(!className.empty() && "className must be set");
528 static constexpr char fmt[] = R
"(
529MlirAttribute {0}{1}{2}Get{3}(MlirOperation op) {{
530 return mlirOperationGetAttributeByName(op, mlirStringRefCreateFromCString("{4}"));
533 assert(!className.empty() && "className must be set");
546 static constexpr char fmt[] = R
"(
547void {0}{1}{2}Set{3}(MlirOperation op, MlirAttribute attr) {{
548 mlirOperationSetAttributeByName(op, mlirStringRefCreateFromCString("{4}"), attr);
551 assert(!className.empty() && "className must be set");
564 static constexpr char fmt[] = R
"(
565MlirValue {0}{1}{2}Get{3}(MlirOperation op) {{
566 return mlirOperationGetResult(op, {4});
569 assert(!className.empty() && "className must be set");
582 static constexpr char fmt[] = R
"(
583intptr_t {0}{1}{2}Get{3}Count(MlirOperation op) {{
584 intptr_t count = mlirOperationGetNumResults(op);
585 assert(count >= {4} && "result count less than start index");
589MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index) {{
590 return mlirOperationGetResult(op, {4} + index);
593 assert(!className.empty() && "className must be set");
606 static constexpr char fmt[] = R
"(
607MlirRegion {0}{1}{2}Get{3}(MlirOperation op) {{
608 return mlirOperationGetRegion(op, {4});
611 assert(!className.empty() && "className must be set");
624 static constexpr char fmt[] = R
"(
625intptr_t {0}{1}{2}Get{3}Count(MlirOperation op) {{
626 intptr_t count = mlirOperationGetNumRegions(op);
627 assert(count >= {4} && "region count less than start index");
631MlirRegion {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index) {{
632 return mlirOperationGetRegion(op, {4} + index);
635 assert(!className.empty() && "className must be set");
653static std::string generateCAPIAssignments(
const Operator &op) {
660 void genResultInferred(llvm::raw_ostream &os)
override {
661 os <<
" mlirOperationStateEnableResultTypeInference(&state);\n";
664 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
666 if (result.isVariadic()) {
668 " mlirOperationStateAddResults(&state, {0}Size, {0}Types);\n", resultName
671 os << llvm::formatv(
" mlirOperationStateAddResults(&state, 1, &{0}Type);\n", resultName);
674 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
675 if (operand.isVariadic()) {
677 " mlirOperationStateAddOperands(&state, {0}Size, {0});\n", operand.name
680 os << llvm::formatv(
" mlirOperationStateAddOperands(&state, 1, &{0});\n", operand.name);
683 void genAttributesPrefix(llvm::raw_ostream &os,
const mlir::tblgen::Operator &op)
override {
684 os <<
" MlirContext ctx = mlirOpBuilderGetContext(builder);\n";
685 os <<
" llvm::SmallVector<MlirNamedAttribute, " << op.getNumAttributes()
686 <<
"> attributes;\n";
688 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
693 std::string attrValue;
694 if (attrType.has_value() && attrType.value() ==
"MlirIdentifier") {
695 attrValue =
"reinterpret_cast<MlirAttribute&>(" + attr.name.str() +
")";
697 attrValue = attr.name.str();
700 os <<
" if (!mlirAttributeIsNull(" << attrValue <<
")) {\n";
701 os <<
" attributes.push_back(mlirNamedAttributeGet(mlirIdentifierGet(ctx, "
702 <<
"mlirStringRefCreateFromCString(\"" << attr.name <<
"\")), " << attrValue <<
"));\n";
705 void genAttributesSuffix(llvm::raw_ostream &os,
const mlir::tblgen::Operator &op)
override {
706 os <<
" mlirOperationStateAddAttributes(&state, attributes.size(), attributes.data());\n";
708 void genRegionsPrefix(llvm::raw_ostream &os,
const mlir::tblgen::Operator &op)
override {
709 os <<
" llvm::SmallVector<MlirRegion, " << op.getNumRegions() <<
"> regions;\n";
711 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
712 if (region.isVariadic()) {
713 os << llvm::formatv(
" for (unsigned i = 0; i < {0}Count; ++i)\n ", region.name);
715 os <<
" regions.push_back(mlirRegionCreate());\n";
717 void genRegionsSuffix(llvm::raw_ostream &os,
const mlir::tblgen::Operator &op)
override {
718 os <<
" mlirOperationStateAddOwnedRegions(&state, regions.size(), regions.data());\n";
720 } paramStringGenerator;
721 return paramStringGenerator.gen(op);
725static bool emitOpCAPIImpl(
const llvm::RecordKeeper &records, raw_ostream &os) {
726 emitSourceFileHeader(
"Op C API Definitions", os, records);
733 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
734 const Operator op(def);
735 const Dialect &dialect = op.getDialect();
736 generator.setDialectAndClassName(&dialect, op.getCppClassName());
739 if (
GenOpBuild && !op.skipDefaultBuilders()) {
740 std::string assignments = generateCAPIAssignments(op);
741 generator.genOpBuildImpl(op.getOperationName(), generateCAPIBuildParams(op), assignments);
746 generator.genIsAImpl();
750 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
751 const auto &operand = op.getOperand(i);
752 generator.setOperandName(operand.name);
753 if (operand.isVariadic()) {
755 generator.genVariadicOperandGetterImpl(i);
758 generator.genVariadicOperandSetterImpl(i);
762 generator.genOperandGetterImpl(i);
765 generator.genOperandSetterImpl(i);
771 for (
const auto &namedAttr : op.getAttributes()) {
772 generator.setAttributeName(namedAttr.name);
774 generator.genAttributeGetterImpl(namedAttr.name);
777 generator.genAttributeSetterImpl(namedAttr.name);
783 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
784 const auto &result = op.getResult(i);
785 generator.setResultName(result.name, i);
786 if (result.isVariadic()) {
787 generator.genVariadicResultGetterImpl(i);
789 generator.genResultGetterImpl(i);
796 for (
unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
797 const auto ®ion = op.getRegion(i);
798 generator.setRegionName(region.name, i);
799 if (region.isVariadic()) {
800 generator.genVariadicRegionGetterImpl(i);
802 generator.genRegionGetterImpl(i);
809 generator.genExtraMethods(op.getExtraClassDeclaration());
816static mlir::GenRegistration
817 genOpCAPIHeader(
"gen-op-capi-header",
"Generate operation C API header", &emitOpCAPIHeader);
819static mlir::GenRegistration
820 genOpCAPIImpl(
"gen-op-capi-impl",
"Generate operation C API implementation", &emitOpCAPIImpl);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Helper struct to generate a string from an operation's operand, attribute, result, and region pieces.
mlir::StringRef className
const mlir::tblgen::Dialect * dialect
std::string dialectNameCapitalized
Generator for common C implementation file elements.
Functionality shared by the header and implementation generators for operations.
void setAttributeName(mlir::StringRef name)
void setResultName(mlir::StringRef name, int resultIndex)
void setOperandName(mlir::StringRef name)
std::string resultNameCapitalized
std::string attrNameCapitalized
std::string operandNameCapitalized
void setRegionName(mlir::StringRef name, unsigned regionIndex)
std::string regionNameCapitalized
Generator for operation C implementation files.
void genOperandGetterImpl(int index) const
void genVariadicOperandGetterImpl(int startIdx) const
void genVariadicOperandSetterImpl(int startIdx) const
void genRegionGetterImpl(unsigned index) const
virtual ~OpImplementationGenerator()=default
void genAttributeGetterImpl(mlir::StringRef attrName) const
void genVariadicRegionGetterImpl(unsigned startIdx) const
void genVariadicResultGetterImpl(int startIdx) const
void genOpBuildImpl(const std::string &operationName, const std::string &params, const std::string &assignments) const
Generate operation "Build" function implementation.
void genResultGetterImpl(int index) const
void genOperandSetterImpl(int index) const
void genAttributeSetterImpl(mlir::StringRef attrName) const