37#include <mlir/TableGen/GenInfo.h>
38#include <mlir/TableGen/Operator.h>
40#include <llvm/ADT/StringExtras.h>
41#include <llvm/Support/CommandLine.h>
42#include <llvm/Support/FormatVariadic.h>
43#include <llvm/TableGen/Record.h>
44#include <llvm/TableGen/TableGenBackend.h>
50using namespace mlir::tblgen;
62 OpTestGenerator(llvm::raw_ostream &outputStream) : TestGenerator(
"Operation", outputStream) {}
69 virtual std::string genCleanup()
const override {
return "mlirOperationDestroy"; };
73 void genBuildOpTest(
const Operator &op)
const {
74 static constexpr char fmt[] = R
"(
75// This test ensures {0}{1}{2}Build links properly.
76TEST_F({1}OperationLinkTests, {0}{2}_Build) {{
77 // Returns an `arith.constant` op, which will never match the {2} dialect check.
78 auto testOp = createIndexOperation();
80 // This condition is always false, so the function is never actually called.
81 // We only verify it compiles and links correctly.
82 if ({0}OperationIsA{1}{2}(testOp)) {{
83 MlirOpBuilder builder = mlirOpBuilderCreate(context);
84 MlirLocation location = mlirLocationUnknownGet(context);
86 (void){0}{1}{2}Build(builder, location{4});
87 // No need to destroy builder or op since this code never runs.
90 mlirOperationDestroy(testOp);
94 assert(!className.empty() && "className must be set");
98 dialectNameCapitalized,
100 generateBuildDummyParams(op),
101 generateBuildParamList(op)
107 void genOperandTests(
const Operator &op)
const {
108 static constexpr char OperandGetterTest[] = R
"(
109TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
110 auto testOp = createIndexOperation();
112 if ({0}OperationIsA{1}{2}(testOp)) {{
113 (void){0}{1}{2}Get{3}(testOp);
116 mlirOperationDestroy(testOp);
120 static constexpr char OperandSetterTest[] = R
"(
121TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}) {{
122 auto testOp = createIndexOperation();
124 if ({0}OperationIsA{1}{2}(testOp)) {{
125 auto dummyValue = mlirOperationGetResult(testOp, 0);
126 {0}{1}{2}Set{3}(testOp, dummyValue);
129 mlirOperationDestroy(testOp);
133 static constexpr char VariadicOperandGetterTest[] = R
"(
134TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
135 auto testOp = createIndexOperation();
137 if ({0}OperationIsA{1}{2}(testOp)) {{
138 (void){0}{1}{2}Get{3}Count(testOp);
141 mlirOperationDestroy(testOp);
144TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
145 auto testOp = createIndexOperation();
147 if ({0}OperationIsA{1}{2}(testOp)) {{
148 (void){0}{1}{2}Get{3}At(testOp, 0);
151 mlirOperationDestroy(testOp);
155 static constexpr char VariadicOperandSetterTest[] = R
"(
156TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_Variadic) {{
157 auto testOp = createIndexOperation();
159 if ({0}OperationIsA{1}{2}(testOp)) {{
160 auto dummyValue = mlirOperationGetResult(testOp, 0);
161 MlirValue values[] = {{dummyValue};
162 {0}{1}{2}Set{3}(testOp, 1, values);
165 mlirOperationDestroy(testOp);
168 assert(!className.empty() && "className must be set");
170 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
171 const auto &operand = op.getOperand(i);
173 if (operand.isVariadic()) {
176 VariadicOperandGetterTest,
178 dialectNameCapitalized,
185 VariadicOperandSetterTest,
187 dialectNameCapitalized,
197 dialectNameCapitalized,
206 dialectNameCapitalized,
217 void genAttributeTests(
const Operator &op)
const {
218 static constexpr char AttributeGetterTest[] = R
"(
219TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Attr) {{
220 auto testOp = createIndexOperation();
222 if ({0}OperationIsA{1}{2}(testOp)) {{
223 (void){0}{1}{2}Get{3}(testOp);
226 mlirOperationDestroy(testOp);
230 static constexpr char AttributeSetterTest[] = R
"(
231TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}Attr) {{
232 auto testOp = createIndexOperation();
234 if ({0}OperationIsA{1}{2}(testOp)) {{
235 {0}{1}{2}Set{3}(testOp, createIndexAttribute());
238 mlirOperationDestroy(testOp);
241 assert(!className.empty() && "className must be set");
243 for (
const auto &namedAttr : op.getAttributes()) {
249 dialectNameCapitalized,
258 dialectNameCapitalized,
268 void genResultTests(
const Operator &op)
const {
269 static constexpr char ResultGetterTest[] = R
"(
270TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
271 auto testOp = createIndexOperation();
273 if ({0}OperationIsA{1}{2}(testOp)) {{
274 (void){0}{1}{2}Get{3}(testOp);
277 mlirOperationDestroy(testOp);
281 static constexpr char VariadicResultGetterTest[] = R
"(
282TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
283 auto testOp = createIndexOperation();
285 if ({0}OperationIsA{1}{2}(testOp)) {{
286 (void){0}{1}{2}Get{3}Count(testOp);
289 mlirOperationDestroy(testOp);
292TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
293 auto testOp = createIndexOperation();
295 if ({0}OperationIsA{1}{2}(testOp)) {{
296 (void){0}{1}{2}Get{3}At(testOp, 0);
299 mlirOperationDestroy(testOp);
302 assert(!className.empty() && "className must be set");
304 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
305 const auto &result = op.getResult(i);
306 llvm::StringRef name = result.name;
307 std::string capName = name.empty() ? llvm::formatv(
"Result{0}", i).str() :
toPascalCase(name);
309 if (result.isVariadic()) {
311 VariadicResultGetterTest,
313 dialectNameCapitalized,
321 dialectNameCapitalized,
331 void genRegionTests(
const Operator &op)
const {
332 static constexpr char RegionGetterTest[] = R
"(
333TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Region) {{
334 auto testOp = createIndexOperation();
336 if ({0}OperationIsA{1}{2}(testOp)) {{
337 (void){0}{1}{2}Get{3}(testOp);
340 mlirOperationDestroy(testOp);
344 static constexpr char VariadicRegionGetterTest[] = R
"(
345TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
346 auto testOp = createIndexOperation();
348 if ({0}OperationIsA{1}{2}(testOp)) {{
349 (void){0}{1}{2}Get{3}Count(testOp);
352 mlirOperationDestroy(testOp);
355TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
356 auto testOp = createIndexOperation();
358 if ({0}OperationIsA{1}{2}(testOp)) {{
359 (void){0}{1}{2}Get{3}At(testOp, 0);
362 mlirOperationDestroy(testOp);
365 assert(!className.empty() && "className must be set");
367 for (
int i = 0, e = op.getNumRegions(); i < e; ++i) {
368 const auto ®ion = op.getRegion(i);
369 llvm::StringRef name = region.name;
370 std::string capName = name.empty() ? llvm::formatv(
"Region{0}", i).str() :
toPascalCase(name);
372 if (region.isVariadic()) {
374 VariadicRegionGetterTest,
376 dialectNameCapitalized,
384 dialectNameCapitalized,
394 void genCompleteRecord(
const Operator &op) {
395 const Dialect &defDialect = op.getDialect();
402 this->setDialectAndClassName(&defDialect, op.getCppClassName());
407 if (
GenOpBuild && !op.skipDefaultBuilders()) {
408 this->genBuildOpTest(op);
411 this->genOperandTests(op);
414 this->genAttributeTests(op);
417 this->genRegionTests(op);
420 this->genResultTests(op);
423 this->genExtraMethods(op.getExtraClassDeclaration());
431 static std::string generateBuildDummyParams(
const Operator &op) {
432 struct : GenStringFromOpPieces {
433 void genHeader(llvm::raw_ostream &os)
override {
435 os <<
" auto dummyValue = mlirOperationGetResult(testOp, 0);\n";
438 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
440 if (result.isVariadic()) {
442 " auto {0}TypeArray = createIndexType();\n"
443 " MlirType {0}Types[] = {{{0}TypeArray};\n"
444 " intptr_t {0}Size = 0;\n",
448 os << llvm::formatv(
" auto {0}Type = createIndexType();\n", resultName);
451 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
454 if (operand.isVariadic()) {
456 " MlirValue {0}Values[] = {{dummyValue};\n"
457 " intptr_t {0}Size = 0;\n",
462 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
465 if (attrType.has_value() && attrType.value() ==
"MlirIdentifier") {
466 rhs =
"mlirOperationGetName(testOp)";
468 rhs =
"createIndexAttribute()";
470 os << llvm::formatv(
" auto {0}Attr = {1};\n", attr.name, rhs);
472 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
473 if (region.isVariadic()) {
474 os << llvm::formatv(
" unsigned {0}Count = 0;\n", region.name);
477 } paramsStringGenerator;
478 return paramsStringGenerator.gen(op);
484 static std::string generateBuildParamList(
const Operator &op) {
485 struct : GenStringFromOpPieces {
487 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
489 if (result.isVariadic()) {
490 os << llvm::formatv(
", {0}Size, {0}Types", resultName);
492 os << llvm::formatv(
", {0}Type", resultName);
495 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
496 if (operand.isVariadic()) {
497 os << llvm::formatv(
", {0}Size, {0}Values", operand.name);
499 os <<
", dummyValue";
502 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
503 os << llvm::formatv(
", {0}Attr", attr.name);
505 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
506 if (region.isVariadic()) {
507 os << llvm::formatv(
", {0}Count", region.name);
510 } paramsStringGenerator;
511 return paramsStringGenerator.gen(op);
518static bool emitOpCAPITests(
const llvm::RecordKeeper &records, raw_ostream &os) {
520 emitSourceFileHeader(
"Op C API Tests", os, records);
523 OpTestGenerator generator(os);
526 generator.genTestClassPrologue();
529 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
531 generator.genCompleteRecord(op);
537static mlir::GenRegistration
538 genOpCAPITests(
"gen-op-capi-tests",
"Generate operation C API unit tests", &emitOpCAPITests);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for common test implementation file elements.