LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
OpCAPIGen.cpp
Go to the documentation of this file.
1//===- OpCAPIGen.cpp - C API generator for operations ---------------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// OpCAPIGen uses the description of operations to generate C API for the ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include <mlir/TableGen/GenInfo.h>
15#include <mlir/TableGen/Operator.h>
16
17#include <llvm/Support/CommandLine.h>
18#include <llvm/Support/FormatVariadic.h>
19#include <llvm/TableGen/Error.h>
20#include <llvm/TableGen/Record.h>
21#include <llvm/TableGen/TableGenBackend.h>
22
23#include <string>
24#include <vector>
25
26#include "CommonCAPIGen.h"
27#include "OpCAPIParamHelper.h"
28
29using namespace mlir;
30using namespace mlir::tblgen;
31
34 void setOperandName(mlir::StringRef name) { this->operandNameCapitalized = toPascalCase(name); }
35 void setAttributeName(mlir::StringRef name) { this->attrNameCapitalized = toPascalCase(name); }
36 void setResultName(mlir::StringRef name, int resultIndex) {
38 name.empty() ? llvm::formatv("Result{0}", resultIndex).str() : toPascalCase(name);
39 }
40 void setRegionName(mlir::StringRef name, unsigned regionIndex) {
42 name.empty() ? llvm::formatv("Region{0}", regionIndex).str() : toPascalCase(name);
43 }
44
45protected:
50};
51
54 using HeaderGenerator::HeaderGenerator;
55 virtual ~OpHeaderGenerator() = default;
56
57 void genOpBuildDecl(const std::string &params) const {
58 static constexpr char fmt[] = R"(
59/* Build a {4}::{2} Operation. */
60MLIR_CAPI_EXPORTED MlirOperation {0}{1}{2}Build(MlirOpBuilder builder, MlirLocation location{3});
61)";
62 assert(!className.empty() && "className must be set");
63 os << llvm::formatv(
64 fmt,
65 FunctionPrefix, // {0}
67 className, // {2}
68 params, // {3}
69 dialect->getCppNamespace() // {4}
70 );
71 }
72
73 void genOperandGetterDecl() const {
74 static constexpr char fmt[] = R"(
75/* Get {3} operand from {4}::{2} Operation. */
76MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}(MlirOperation op);
77)";
78 assert(!className.empty() && "className must be set");
79 assert(!operandNameCapitalized.empty() && "operandName must be set");
80 os << llvm::formatv(
81 fmt,
82 FunctionPrefix, // {0}
84 className, // {2}
86 dialect->getCppNamespace() // {4}
87 );
88 }
89
90 void genOperandSetterDecl() const {
91 static constexpr char fmt[] = R"(
92/* Set {3} operand of {4}::{2} Operation. */
93MLIR_CAPI_EXPORTED void {0}{1}{2}Set{3}(MlirOperation op, MlirValue value);
94)";
95 assert(!className.empty() && "className must be set");
96 assert(!operandNameCapitalized.empty() && "operandName must be set");
97 os << llvm::formatv(
98 fmt,
99 FunctionPrefix, // {0}
101 className, // {2}
103 dialect->getCppNamespace() // {4}
104 );
105 }
106
107 void genVariadicOperandGetterDecl() const {
108 static constexpr char fmt[] = R"(
109/* Get number of {3} operands in {4}::{2} Operation. */
110MLIR_CAPI_EXPORTED intptr_t {0}{1}{2}Get{3}Count(MlirOperation op);
111
112/* Get {3} operand at index from {4}::{2} Operation. */
113MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index);
114)";
115 assert(!className.empty() && "className must be set");
116 assert(!operandNameCapitalized.empty() && "operandName must be set");
117 os << llvm::formatv(
118 fmt,
119 FunctionPrefix, // {0}
121 className, // {2}
123 dialect->getCppNamespace() // {4}
124
125 );
126 }
127
129 static constexpr char fmt[] = R"(
130/* Set {3} operands of {4}::{2} Operation. */
131MLIR_CAPI_EXPORTED void {0}{1}{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values);
132)";
133 assert(!className.empty() && "className must be set");
134 assert(!operandNameCapitalized.empty() && "operandName must be set");
135 os << llvm::formatv(
136 fmt,
137 FunctionPrefix, // {0}
139 className, // {2}
141 dialect->getCppNamespace() // {4}
143 );
144 }
145
146 void genAttributeGetterDecl() const {
147 static constexpr char fmt[] = R"(
148/* Get {3} attribute from {4}::{2} Operation. */
149MLIR_CAPI_EXPORTED MlirAttribute {0}{1}{2}Get{3}(MlirOperation op);
150)";
151 assert(!className.empty() && "className must be set");
152 assert(!attrNameCapitalized.empty() && "attrName must be set");
153 os << llvm::formatv(
154 fmt,
155 FunctionPrefix, // {0}
157 className, // {2}
158 attrNameCapitalized, // {3}
159 dialect->getCppNamespace() // {4}
160 );
161 }
162
163 void genAttributeSetterDecl() const {
164 static constexpr char fmt[] = R"(
165/* Set {3} attribute of {4}::{2} Operation. */
166MLIR_CAPI_EXPORTED void {0}{1}{2}Set{3}(MlirOperation op, MlirAttribute attr);
167)";
168 assert(!className.empty() && "className must be set");
169 assert(!attrNameCapitalized.empty() && "attrName must be set");
170 os << llvm::formatv(
171 fmt,
172 FunctionPrefix, // {0}
174 className, // {2}
175 attrNameCapitalized, // {3}
176 dialect->getCppNamespace() // {4}
177 );
178 }
179
180 void genResultGetterDecl() const {
181 static constexpr char fmt[] = R"(
182/* Get {3} result from {4}::{2} Operation. */
183MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}(MlirOperation op);
184)";
185 assert(!className.empty() && "className must be set");
186 assert(!resultNameCapitalized.empty() && "resultName must be set");
187 os << llvm::formatv(
188 fmt,
189 FunctionPrefix, // {0}
191 className, // {2}
193 dialect->getCppNamespace() // {4}
194 );
195 }
196
197 void genVariadicResultGetterDecl() const {
198 static constexpr char fmt[] = R"(
199/* Get number of {3} results in {4}::{2} Operation. */
200MLIR_CAPI_EXPORTED intptr_t {0}{1}{2}Get{3}Count(MlirOperation op);
201
202/* Get {3} result at index from {4}::{2} Operation. */
203MLIR_CAPI_EXPORTED MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index);
204)";
205 assert(!className.empty() && "className must be set");
206 assert(!resultNameCapitalized.empty() && "resultName must be set");
207 os << llvm::formatv(
208 fmt,
209 FunctionPrefix, // {0}
211 className, // {2}
213 dialect->getCppNamespace() // {4}
214 );
215 }
216
217 void genRegionGetterDecl() const {
218 static constexpr char fmt[] = R"(
219/* Get {3} region from {4}::{2} Operation. */
220MLIR_CAPI_EXPORTED MlirRegion {0}{1}{2}Get{3}(MlirOperation op);
221)";
222 assert(!className.empty() && "className must be set");
223 assert(!regionNameCapitalized.empty() && "regionName must be set");
224 os << llvm::formatv(
225 fmt,
226 FunctionPrefix, // {0}
228 className, // {2}
230 dialect->getCppNamespace() // {4}
231 );
232 }
233
234 void genVariadicRegionGetterDecl() const {
235 static constexpr char fmt[] = R"(
236/* Get number of {3} regions in {4}::{2} Operation. */
237MLIR_CAPI_EXPORTED intptr_t {0}{1}{2}Get{3}Count(MlirOperation op);
238
239/* Get {3} region at index from {4}::{2} Operation. */
240MLIR_CAPI_EXPORTED MlirRegion {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index);
241)";
242 assert(!className.empty() && "className must be set");
243 assert(!regionNameCapitalized.empty() && "regionName must be set");
244 os << llvm::formatv(
245 fmt,
246 FunctionPrefix, // {0}
248 className, // {2}
250 dialect->getCppNamespace() // {4}
251 );
252 }
253};
254
260static std::string generateCAPIBuildParams(const Operator &op) {
261 struct : GenStringFromOpPieces {
262 void genResult(
263 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
264 ) override {
265 if (result.isVariadic()) {
266 os << llvm::formatv(", intptr_t {0}Size, MlirType const *{0}Types", resultName);
267 } else {
268 os << llvm::formatv(", MlirType {0}Type", resultName);
269 }
270 }
271 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
272 if (operand.isVariadic()) {
273 os << llvm::formatv(", intptr_t {0}Size, MlirValue const *{0}", operand.name);
274 } else {
275 os << llvm::formatv(", MlirValue {0}", operand.name);
276 }
277 }
278 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
279 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
280 os << llvm::formatv(", {0} {1}", attrType.value_or("MlirAttribute"), attr.name);
281 }
282 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
283 if (region.isVariadic()) {
284 os << llvm::formatv(", unsigned {0}Count", region.name);
285 }
286 }
287 } paramStringGenerator;
288 return paramStringGenerator.gen(op);
289}
290
292static bool emitOpCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
293 emitSourceFileHeader("Op C API Declarations", os, records);
294
295 OpHeaderGenerator generator("Operation", os);
296 generator.genPrologue();
297
298 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
299 const Operator op(def);
300 const Dialect &dialect = op.getDialect();
301
302 // Generate for the selected dialect only (specified via -dialect command-line option)
303 if (dialect.getName() != DialectName) {
304 continue;
305 }
306
307 generator.setDialectAndClassName(&dialect, op.getCppClassName());
308
309 // Generate "Build" function
310 if (GenOpBuild && !op.skipDefaultBuilders()) {
311 generator.genOpBuildDecl(generateCAPIBuildParams(op));
312 }
313
314 // Generate IsA check
315 if (GenIsA) {
316 generator.genIsADecl();
317 }
318
319 // Generate operand getters and setters
320 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
321 const auto &operand = op.getOperand(i);
322 generator.setOperandName(operand.name);
323 if (operand.isVariadic()) {
325 generator.genVariadicOperandGetterDecl();
326 }
328 generator.genVariadicOperandSetterDecl();
329 }
330 } else {
332 generator.genOperandGetterDecl();
333 }
335 generator.genOperandSetterDecl();
336 }
337 }
338 }
339
340 // Generate attribute getters and setters
341 for (const auto &namedAttr : op.getAttributes()) {
342 generator.setAttributeName(namedAttr.name);
344 generator.genAttributeGetterDecl();
345 }
347 generator.genAttributeSetterDecl();
348 }
349 }
350
351 // Generate result getters
352 if (GenOpResultGetters) {
353 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
354 const auto &result = op.getResult(i);
355 generator.setResultName(result.name, i);
356 if (result.isVariadic()) {
357 generator.genVariadicResultGetterDecl();
358 } else {
359 generator.genResultGetterDecl();
360 }
361 }
362 }
363
364 // Generate region getters
365 if (GenOpRegionGetters) {
366 for (unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
367 const auto &region = op.getRegion(i);
368 generator.setRegionName(region.name, i);
369 if (region.isVariadic()) {
370 generator.genVariadicRegionGetterDecl();
371 } else {
372 generator.genRegionGetterDecl();
373 }
374 }
375 }
376
377 // Generate extra class method wrappers
379 generator.genExtraMethods(op.getExtraClassDeclaration());
380 }
381 }
382
383 generator.genEpilogue();
384 return false;
385}
386
389 using ImplementationGenerator::ImplementationGenerator;
390 virtual ~OpImplementationGenerator() = default;
391
397 const std::string &operationName, const std::string &params, const std::string &assignments
398 ) const {
399 static constexpr char fmt[] = R"(
400MlirOperation {0}{1}{2}Build(MlirOpBuilder builder, MlirLocation location{3}) {{
401 MlirOperationState state = mlirOperationStateGet(mlirStringRefCreateFromCString("{4}"), location);
402{5}
403 return mlirOpBuilderInsert(builder, mlirOperationCreate(&state));
404}
405)";
406 assert(!className.empty() && "className must be set");
407 os << llvm::formatv(
408 fmt,
409 FunctionPrefix, // {0}
411 className, // {2}
412 params, // {3}
413 operationName, // {4}
414 assignments // {5}
415 );
416 }
417
418 void genOperandGetterImpl(int index) const {
419 static constexpr char fmt[] = R"(
420MlirValue {0}{1}{2}Get{3}(MlirOperation op) {{
421 return mlirOperationGetOperand(op, {4});
422}
423)";
424 assert(!className.empty() && "className must be set");
425 assert(!operandNameCapitalized.empty() && "operandName must be set");
426 os << llvm::formatv(
427 fmt,
428 FunctionPrefix, // {0}
430 className, // {2}
432 index // {4}
433 );
434 }
435
436 void genOperandSetterImpl(int index) const {
437 static constexpr char fmt[] = R"(
438void {0}{1}{2}Set{3}(MlirOperation op, MlirValue value) {{
439 mlirOperationSetOperand(op, {4}, value);
441)";
442 assert(!className.empty() && "className must be set");
443 assert(!operandNameCapitalized.empty() && "operandName must be set");
444 os << llvm::formatv(
445 fmt,
446 FunctionPrefix, // {0}
448 className, // {2}
450 index // {4}
451 );
452 }
453
454 void genVariadicOperandGetterImpl(int startIdx) const {
455 static constexpr char fmt[] = R"(
456intptr_t {0}{1}{2}Get{3}Count(MlirOperation op) {{
457 intptr_t count = mlirOperationGetNumOperands(op);
458 assert(count >= {4} && "operand count less than start index");
459 return count - {4};
460}
461
462MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index) {{
463 return mlirOperationGetOperand(op, {4} + index);
464}
465)";
466 assert(!className.empty() && "className must be set");
467 assert(!operandNameCapitalized.empty() && "operandName must be set");
468 os << llvm::formatv(
469 fmt,
470 FunctionPrefix, // {0}
472 className, // {2}
474 startIdx // {4}
475 );
476 }
477
478 void genVariadicOperandSetterImpl(int startIdx) const {
479 static constexpr char fmt[] = R"(
480void {0}{1}{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values) {{
481 intptr_t numOperands = mlirOperationGetNumOperands(op);
482 intptr_t startIdx = {4};
483
484 // Validate bounds
485 if (startIdx < 0 || startIdx > numOperands) {{
486 return;
487 }
488 if (count < 0 || count > (std::numeric_limits<intptr_t>::max() - startIdx)) {{
489 return;
490 }
491
492 intptr_t oldCount = numOperands - startIdx;
493 intptr_t newNumOperands = startIdx + count;
494
495 std::vector<MlirValue> newOperands(newNumOperands);
497 // Copy operands before this variadic group
498 for (intptr_t i = 0; i < startIdx; ++i) {{
499 newOperands[i] = mlirOperationGetOperand(op, i);
500 }
501
502 // Copy new variadic operands
503 for (intptr_t i = 0; i < count; ++i) {{
504 newOperands[startIdx + i] = values[i];
505 }
506
507 // Copy operands after this variadic group
508 for (intptr_t i = startIdx + oldCount; i < numOperands; ++i) {{
509 newOperands[i - oldCount + count] = mlirOperationGetOperand(op, i);
511
512 mlirOperationSetOperands(op, newNumOperands, newOperands.data());
513}
514)";
515 assert(!className.empty() && "className must be set");
516 assert(!operandNameCapitalized.empty() && "operandName must be set");
517 os << llvm::formatv(
518 fmt,
519 FunctionPrefix, // {0}
521 className, // {2}
523 startIdx // {4}
524 );
525 }
526
527 void genAttributeGetterImpl(mlir::StringRef attrName) const {
528 static constexpr char fmt[] = R"(
529MlirAttribute {0}{1}{2}Get{3}(MlirOperation op) {{
530 return mlirOperationGetAttributeByName(op, mlirStringRefCreateFromCString("{4}"));
531}
532)";
533 assert(!className.empty() && "className must be set");
534 assert(!attrNameCapitalized.empty() && "attrName must be set");
535 os << llvm::formatv(
536 fmt,
537 FunctionPrefix, // {0}
539 className, // {2}
540 attrNameCapitalized, // {3}
541 attrName // {4}
542 );
543 }
544
545 void genAttributeSetterImpl(mlir::StringRef attrName) const {
546 static constexpr char fmt[] = R"(
547void {0}{1}{2}Set{3}(MlirOperation op, MlirAttribute attr) {{
548 mlirOperationSetAttributeByName(op, mlirStringRefCreateFromCString("{4}"), attr);
549}
550)";
551 assert(!className.empty() && "className must be set");
552 assert(!attrNameCapitalized.empty() && "attrName must be set");
553 os << llvm::formatv(
554 fmt,
555 FunctionPrefix, // {0}
557 className, // {2}
558 attrNameCapitalized, // {3}
559 attrName // {4}
560 );
561 }
562
563 void genResultGetterImpl(int index) const {
564 static constexpr char fmt[] = R"(
565MlirValue {0}{1}{2}Get{3}(MlirOperation op) {{
566 return mlirOperationGetResult(op, {4});
567}
568)";
569 assert(!className.empty() && "className must be set");
570 assert(!resultNameCapitalized.empty() && "resultName must be set");
571 os << llvm::formatv(
572 fmt,
573 FunctionPrefix, // {0}
575 className, // {2}
577 index // {4}
578 );
579 }
580
581 void genVariadicResultGetterImpl(int startIdx) const {
582 static constexpr char fmt[] = R"(
583intptr_t {0}{1}{2}Get{3}Count(MlirOperation op) {{
584 intptr_t count = mlirOperationGetNumResults(op);
585 assert(count >= {4} && "result count less than start index");
586 return count - {4};
587}
588
589MlirValue {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index) {{
590 return mlirOperationGetResult(op, {4} + index);
591}
592)";
593 assert(!className.empty() && "className must be set");
594 assert(!resultNameCapitalized.empty() && "resultName must be set");
595 os << llvm::formatv(
596 fmt,
597 FunctionPrefix, // {0}
599 className, // {2}
601 startIdx // {4}
602 );
603 }
604
605 void genRegionGetterImpl(unsigned index) const {
606 static constexpr char fmt[] = R"(
607MlirRegion {0}{1}{2}Get{3}(MlirOperation op) {{
608 return mlirOperationGetRegion(op, {4});
609}
610)";
611 assert(!className.empty() && "className must be set");
612 assert(!regionNameCapitalized.empty() && "regionName must be set");
613 os << llvm::formatv(
614 fmt,
615 FunctionPrefix, // {0}
617 className, // {2}
619 index // {4}
620 );
621 }
622
623 void genVariadicRegionGetterImpl(unsigned startIdx) const {
624 static constexpr char fmt[] = R"(
625intptr_t {0}{1}{2}Get{3}Count(MlirOperation op) {{
626 intptr_t count = mlirOperationGetNumRegions(op);
627 assert(count >= {4} && "region count less than start index");
628 return count - {4};
629}
630
631MlirRegion {0}{1}{2}Get{3}At(MlirOperation op, intptr_t index) {{
632 return mlirOperationGetRegion(op, {4} + index);
633}
634)";
635 assert(!className.empty() && "className must be set");
636 assert(!regionNameCapitalized.empty() && "regionName must be set");
637 os << llvm::formatv(
638 fmt,
639 FunctionPrefix, // {0}
641 className, // {2}
643 startIdx // {4}
644 );
645 }
646};
647
653static std::string generateCAPIAssignments(const Operator &op) {
654 // Code generated here can use the following variables:
655 // - MlirOpBuilder builder
656 // - MlirLocation location
657 // - MlirOperationState state
658 // - Operand/Attribute/Result parameters per `generateCAPIBuildParams()`
659 struct : GenStringFromOpPieces {
660 void genResultInferred(llvm::raw_ostream &os) override {
661 os << " mlirOperationStateEnableResultTypeInference(&state);\n";
662 }
663 void genResult(
664 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
665 ) override {
666 if (result.isVariadic()) {
667 os << llvm::formatv(
668 " mlirOperationStateAddResults(&state, {0}Size, {0}Types);\n", resultName
669 );
670 } else {
671 os << llvm::formatv(" mlirOperationStateAddResults(&state, 1, &{0}Type);\n", resultName);
672 }
673 }
674 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
675 if (operand.isVariadic()) {
676 os << llvm::formatv(
677 " mlirOperationStateAddOperands(&state, {0}Size, {0});\n", operand.name
678 );
679 } else {
680 os << llvm::formatv(" mlirOperationStateAddOperands(&state, 1, &{0});\n", operand.name);
681 }
682 }
683 void genAttributesPrefix(llvm::raw_ostream &os, const mlir::tblgen::Operator &op) override {
684 os << " MlirContext ctx = mlirOpBuilderGetContext(builder);\n";
685 os << " llvm::SmallVector<MlirNamedAttribute, " << op.getNumAttributes()
686 << "> attributes;\n";
687 }
688 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
689 // The second parameter to `mlirNamedAttributeGet()` must be an "MlirAttribute". However, if
690 // it ends up as "MlirIdentifier", a reinterpret cast is needed. These C structs have the same
691 // layout and the C++ mlir::StringAttr is a subclass of mlir::Attribute so the cast is safe.
692 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
693 std::string attrValue;
694 if (attrType.has_value() && attrType.value() == "MlirIdentifier") {
695 attrValue = "reinterpret_cast<MlirAttribute&>(" + attr.name.str() + ")";
696 } else {
697 attrValue = attr.name.str();
698 }
699
700 os << " if (!mlirAttributeIsNull(" << attrValue << ")) {\n";
701 os << " attributes.push_back(mlirNamedAttributeGet(mlirIdentifierGet(ctx, "
702 << "mlirStringRefCreateFromCString(\"" << attr.name << "\")), " << attrValue << "));\n";
703 os << " }\n";
704 }
705 void genAttributesSuffix(llvm::raw_ostream &os, const mlir::tblgen::Operator &op) override {
706 os << " mlirOperationStateAddAttributes(&state, attributes.size(), attributes.data());\n";
707 }
708 void genRegionsPrefix(llvm::raw_ostream &os, const mlir::tblgen::Operator &op) override {
709 os << " llvm::SmallVector<MlirRegion, " << op.getNumRegions() << "> regions;\n";
710 }
711 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
712 if (region.isVariadic()) {
713 os << llvm::formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", region.name);
714 }
715 os << " regions.push_back(mlirRegionCreate());\n";
716 }
717 void genRegionsSuffix(llvm::raw_ostream &os, const mlir::tblgen::Operator &op) override {
718 os << " mlirOperationStateAddOwnedRegions(&state, regions.size(), regions.data());\n";
719 }
720 } paramStringGenerator;
721 return paramStringGenerator.gen(op);
722}
723
725static bool emitOpCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
726 emitSourceFileHeader("Op C API Definitions", os, records);
727
728 OpImplementationGenerator generator("Operation", os);
729
730 // Capitalize dialect name for function names
731 std::string dialectNameCapitalized = toPascalCase(DialectName);
732
733 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
734 const Operator op(def);
735 const Dialect &dialect = op.getDialect();
736 generator.setDialectAndClassName(&dialect, op.getCppClassName());
737
738 // Generate "Build" function
739 if (GenOpBuild && !op.skipDefaultBuilders()) {
740 std::string assignments = generateCAPIAssignments(op);
741 generator.genOpBuildImpl(op.getOperationName(), generateCAPIBuildParams(op), assignments);
742 }
743
744 // Generate IsA check implementation
745 if (GenIsA) {
746 generator.genIsAImpl();
747 }
748
749 // Generate operand getters and setters
750 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
751 const auto &operand = op.getOperand(i);
752 generator.setOperandName(operand.name);
753 if (operand.isVariadic()) {
755 generator.genVariadicOperandGetterImpl(i);
756 }
758 generator.genVariadicOperandSetterImpl(i);
759 }
760 } else {
762 generator.genOperandGetterImpl(i);
763 }
765 generator.genOperandSetterImpl(i);
766 }
767 }
768 }
769
770 // Generate attribute getters and setters
771 for (const auto &namedAttr : op.getAttributes()) {
772 generator.setAttributeName(namedAttr.name);
774 generator.genAttributeGetterImpl(namedAttr.name);
775 }
777 generator.genAttributeSetterImpl(namedAttr.name);
778 }
779 }
780
781 // Generate result getters
782 if (GenOpResultGetters) {
783 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
784 const auto &result = op.getResult(i);
785 generator.setResultName(result.name, i);
786 if (result.isVariadic()) {
787 generator.genVariadicResultGetterImpl(i);
788 } else {
789 generator.genResultGetterImpl(i);
790 }
791 }
792 }
793
794 // Generate region getters
795 if (GenOpRegionGetters) {
796 for (unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
797 const auto &region = op.getRegion(i);
798 generator.setRegionName(region.name, i);
799 if (region.isVariadic()) {
800 generator.genVariadicRegionGetterImpl(i);
801 } else {
802 generator.genRegionGetterImpl(i);
803 }
804 }
805 }
806
807 // Generate extra class method implementations
809 generator.genExtraMethods(op.getExtraClassDeclaration());
810 }
811 }
812
813 return false;
814}
815
816static mlir::GenRegistration
817 genOpCAPIHeader("gen-op-capi-header", "Generate operation C API header", &emitOpCAPIHeader);
818
819static mlir::GenRegistration
820 genOpCAPIImpl("gen-op-capi-impl", "Generate operation C API implementation", &emitOpCAPIImpl);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Helper struct to generate a string from operation operand, attribute, and result pieces.
mlir::StringRef className
const mlir::tblgen::Dialect * dialect
std::string dialectNameCapitalized
llvm::raw_ostream & os
Generator for common C header file elements.
Generator for common C implementation file elements.
Common between header and implementation generators for operations.
Definition OpCAPIGen.cpp:33
void setAttributeName(mlir::StringRef name)
Definition OpCAPIGen.cpp:35
void setResultName(mlir::StringRef name, int resultIndex)
Definition OpCAPIGen.cpp:36
void setOperandName(mlir::StringRef name)
Definition OpCAPIGen.cpp:34
std::string resultNameCapitalized
Definition OpCAPIGen.cpp:48
std::string attrNameCapitalized
Definition OpCAPIGen.cpp:47
std::string operandNameCapitalized
Definition OpCAPIGen.cpp:46
void setRegionName(mlir::StringRef name, unsigned regionIndex)
Definition OpCAPIGen.cpp:40
std::string regionNameCapitalized
Definition OpCAPIGen.cpp:49
Generator for operation C header files.
Definition OpCAPIGen.cpp:53
void genVariadicRegionGetterDecl() const
void genVariadicResultGetterDecl() const
void genOperandSetterDecl() const
Definition OpCAPIGen.cpp:84
void genVariadicOperandSetterDecl() const
void genAttributeGetterDecl() const
virtual ~OpHeaderGenerator()=default
void genRegionGetterDecl() const
void genAttributeSetterDecl() const
void genVariadicOperandGetterDecl() const
Definition OpCAPIGen.cpp:98
void genResultGetterDecl() const
void genOpBuildDecl(const std::string &params) const
Definition OpCAPIGen.cpp:57
void genOperandGetterDecl() const
Definition OpCAPIGen.cpp:70
Generator for operation C implementation files.
void genOperandGetterImpl(int index) const
void genVariadicOperandGetterImpl(int startIdx) const
void genVariadicOperandSetterImpl(int startIdx) const
void genRegionGetterImpl(unsigned index) const
virtual ~OpImplementationGenerator()=default
void genAttributeGetterImpl(mlir::StringRef attrName) const
void genVariadicRegionGetterImpl(unsigned startIdx) const
void genVariadicResultGetterImpl(int startIdx) const
void genOpBuildImpl(const std::string &operationName, const std::string &params, const std::string &assignments) const
Generate operation "Build" function implementation.
void genResultGetterImpl(int index) const
void genOperandSetterImpl(int index) const
void genAttributeSetterImpl(mlir::StringRef attrName) const