LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
OpCAPITestGen.cpp
Go to the documentation of this file.
1//===- OpCAPITestGen.cpp - C API test generator for operations ------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// OpCAPITestGen generates unit tests for the C API operations generated by
11// OpCAPIGen. These are link-time tests that ensure all generated functions
12// compile and link properly. Each C API call is placed inside a conditional
13// that is always false at runtime; the call is still compiled, so the linker
14// must resolve every referenced function even though none of them executes.
15//
16// Test Strategy:
17// - Each test creates a dummy operation from a different dialect (arith.constant)
18// - Tests then call the generated C API functions inside an if statement that
19// checks if the dummy op is of the target type (always false)
20// - The compiler still verifies the function signatures and the linker ensures
21// the symbols are defined, even though the code never executes at runtime
22//
23// These tests will catch the following kinds of issues:
24// - Functions declared but not defined (link errors)
25// - Signature mismatches between header and implementation
26// - Missing functions in the build system
27// - ABI compatibility issues
28// - Refactoring breaks
29//
30// However, the following issues will NOT be caught:
31// - Generator logic bugs (if generator is wrong, tests will be wrong too)
32// - Runtime behavior
33// - Semantic correctness
34//
35//===----------------------------------------------------------------------===//
36
37#include <mlir/TableGen/GenInfo.h>
38#include <mlir/TableGen/Operator.h>
39
40#include <llvm/ADT/StringExtras.h>
41#include <llvm/Support/CommandLine.h>
42#include <llvm/Support/FormatVariadic.h>
43#include <llvm/TableGen/Record.h>
44#include <llvm/TableGen/TableGenBackend.h>
45
46#include "CommonCAPIGen.h"
47#include "OpCAPIParamHelper.h"
48
49using namespace mlir;
50using namespace mlir::tblgen;
51
52namespace {
53
59struct OpTestGenerator : public TestGenerator {
62 OpTestGenerator(llvm::raw_ostream &outputStream) : TestGenerator("Operation", outputStream) {}
63
69 virtual std::string genCleanup() const override { return "mlirOperationDestroy"; };
70
73 void genBuildOpTest(const Operator &op) const {
74 static constexpr char fmt[] = R"(
75// This test ensures {0}{1}{2}Build links properly.
76TEST_F({1}OperationLinkTests, {0}{2}_Build) {{
77 // Returns an `arith.constant` op, which will never match the {2} dialect check.
78 auto testOp = createIndexOperation();
79
80 // This condition is always false, so the function is never actually called.
81 // We only verify it compiles and links correctly.
82 if ({0}OperationIsA{1}{2}(testOp)) {{
83 MlirOpBuilder builder = mlirOpBuilderCreate(context);
84 MlirLocation location = mlirLocationUnknownGet(context);
85{3}
86 (void){0}{1}{2}Build(builder, location{4});
87 // No need to destroy builder or op since this code never runs.
88 }
89
90 mlirOperationDestroy(testOp);
91}
92)";
93
94 assert(!className.empty() && "className must be set");
95 os << llvm::formatv(
96 fmt,
97 FunctionPrefix, // {0}
98 dialectNameCapitalized, // {1}
99 className, // {2}
100 generateBuildDummyParams(op), // {3}
101 generateBuildParamList(op) // {4}
102 );
103 }
104
107 void genOperandTests(const Operator &op) const {
108 static constexpr char OperandGetterTest[] = R"(
109TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
110 auto testOp = createIndexOperation();
111
112 if ({0}OperationIsA{1}{2}(testOp)) {{
113 (void){0}{1}{2}Get{3}(testOp);
114 }
115
116 mlirOperationDestroy(testOp);
117}
118)";
119
120 static constexpr char OperandSetterTest[] = R"(
121TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}) {{
122 auto testOp = createIndexOperation();
123
124 if ({0}OperationIsA{1}{2}(testOp)) {{
125 auto dummyValue = mlirOperationGetResult(testOp, 0);
126 {0}{1}{2}Set{3}(testOp, dummyValue);
127 }
128
129 mlirOperationDestroy(testOp);
130}
131)";
132
133 static constexpr char VariadicOperandGetterTest[] = R"(
134TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
135 auto testOp = createIndexOperation();
136
137 if ({0}OperationIsA{1}{2}(testOp)) {{
138 (void){0}{1}{2}Get{3}Count(testOp);
139 }
140
141 mlirOperationDestroy(testOp);
142}
143
144TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
145 auto testOp = createIndexOperation();
146
147 if ({0}OperationIsA{1}{2}(testOp)) {{
148 (void){0}{1}{2}Get{3}At(testOp, 0);
149 }
150
151 mlirOperationDestroy(testOp);
152}
153)";
154
155 static constexpr char VariadicOperandSetterTest[] = R"(
156TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_Variadic) {{
157 auto testOp = createIndexOperation();
158
159 if ({0}OperationIsA{1}{2}(testOp)) {{
160 auto dummyValue = mlirOperationGetResult(testOp, 0);
161 MlirValue values[] = {{dummyValue};
162 {0}{1}{2}Set{3}(testOp, 1, values);
163 }
164
165 mlirOperationDestroy(testOp);
166}
167)";
168 assert(!className.empty() && "className must be set");
169
170 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
171 const auto &operand = op.getOperand(i);
172 std::string capName = toPascalCase(operand.name);
173 if (operand.isVariadic()) {
175 os << llvm::formatv(
176 VariadicOperandGetterTest,
177 FunctionPrefix, // {0}
178 dialectNameCapitalized, // {1}
179 className, // {2}
180 capName // {3}
181 );
182 }
184 os << llvm::formatv(
185 VariadicOperandSetterTest,
186 FunctionPrefix, // {0}
187 dialectNameCapitalized, // {1}
188 className, // {2}
189 capName // {3}
190 );
191 }
192 } else {
194 os << llvm::formatv(
195 OperandGetterTest,
196 FunctionPrefix, // {0}
197 dialectNameCapitalized, // {1}
198 className, // {2}
199 capName // {3}
200 );
201 }
203 os << llvm::formatv(
204 OperandSetterTest,
205 FunctionPrefix, // {0}
206 dialectNameCapitalized, // {1}
207 className, // {2}
208 capName // {3}
209 );
210 }
211 }
212 }
213 }
214
217 void genAttributeTests(const Operator &op) const {
218 static constexpr char AttributeGetterTest[] = R"(
219TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Attr) {{
220 auto testOp = createIndexOperation();
221
222 if ({0}OperationIsA{1}{2}(testOp)) {{
223 (void){0}{1}{2}Get{3}(testOp);
224 }
225
226 mlirOperationDestroy(testOp);
227}
228)";
229
230 static constexpr char AttributeSetterTest[] = R"(
231TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}Attr) {{
232 auto testOp = createIndexOperation();
233
234 if ({0}OperationIsA{1}{2}(testOp)) {{
235 {0}{1}{2}Set{3}(testOp, createIndexAttribute());
236 }
237
238 mlirOperationDestroy(testOp);
239}
240)";
241 assert(!className.empty() && "className must be set");
242
243 for (const auto &namedAttr : op.getAttributes()) {
244 std::string capName = toPascalCase(namedAttr.name);
246 os << llvm::formatv(
247 AttributeGetterTest,
248 FunctionPrefix, // {0}
249 dialectNameCapitalized, // {1}
250 className, // {2}
251 capName // {3}
252 );
253 }
255 os << llvm::formatv(
256 AttributeSetterTest,
257 FunctionPrefix, // {0}
258 dialectNameCapitalized, // {1}
259 className, // {2}
260 capName // {3}
261 );
262 }
263 }
264 }
265
268 void genResultTests(const Operator &op) const {
269 static constexpr char ResultGetterTest[] = R"(
270TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
271 auto testOp = createIndexOperation();
272
273 if ({0}OperationIsA{1}{2}(testOp)) {{
274 (void){0}{1}{2}Get{3}(testOp);
275 }
276
277 mlirOperationDestroy(testOp);
278}
279)";
280
281 static constexpr char VariadicResultGetterTest[] = R"(
282TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
283 auto testOp = createIndexOperation();
284
285 if ({0}OperationIsA{1}{2}(testOp)) {{
286 (void){0}{1}{2}Get{3}Count(testOp);
287 }
288
289 mlirOperationDestroy(testOp);
290}
291
292TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
293 auto testOp = createIndexOperation();
294
295 if ({0}OperationIsA{1}{2}(testOp)) {{
296 (void){0}{1}{2}Get{3}At(testOp, 0);
297 }
298
299 mlirOperationDestroy(testOp);
300}
301)";
302 assert(!className.empty() && "className must be set");
303
304 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
305 const auto &result = op.getResult(i);
306 llvm::StringRef name = result.name;
307 std::string capName = name.empty() ? llvm::formatv("Result{0}", i).str() : toPascalCase(name);
308
309 if (result.isVariadic()) {
310 os << llvm::formatv(
311 VariadicResultGetterTest,
312 FunctionPrefix, // {0}
313 dialectNameCapitalized, // {1}
314 className, // {2}
315 capName // {3}
316 );
317 } else {
318 os << llvm::formatv(
319 ResultGetterTest,
320 FunctionPrefix, // {0}
321 dialectNameCapitalized, // {1}
322 className, // {2}
323 capName // {3}
324 );
325 }
326 }
327 }
328
331 void genRegionTests(const Operator &op) const {
332 static constexpr char RegionGetterTest[] = R"(
333TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Region) {{
334 auto testOp = createIndexOperation();
335
336 if ({0}OperationIsA{1}{2}(testOp)) {{
337 (void){0}{1}{2}Get{3}(testOp);
338 }
339
340 mlirOperationDestroy(testOp);
341}
342)";
343
344 static constexpr char VariadicRegionGetterTest[] = R"(
345TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
346 auto testOp = createIndexOperation();
347
348 if ({0}OperationIsA{1}{2}(testOp)) {{
349 (void){0}{1}{2}Get{3}Count(testOp);
350 }
351
352 mlirOperationDestroy(testOp);
353}
354
355TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
356 auto testOp = createIndexOperation();
357
358 if ({0}OperationIsA{1}{2}(testOp)) {{
359 (void){0}{1}{2}Get{3}At(testOp, 0);
360 }
361
362 mlirOperationDestroy(testOp);
363}
364)";
365 assert(!className.empty() && "className must be set");
366
367 for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
368 const auto &region = op.getRegion(i);
369 llvm::StringRef name = region.name;
370 std::string capName = name.empty() ? llvm::formatv("Region{0}", i).str() : toPascalCase(name);
371
372 if (region.isVariadic()) {
373 os << llvm::formatv(
374 VariadicRegionGetterTest,
375 FunctionPrefix, // {0}
376 dialectNameCapitalized, // {1}
377 className, // {2}
378 capName // {3}
379 );
380 } else {
381 os << llvm::formatv(
382 RegionGetterTest,
383 FunctionPrefix, // {0}
384 dialectNameCapitalized, // {1}
385 className, // {2}
386 capName // {3}
387 );
388 }
389 }
390 }
391
394 void genCompleteRecord(const Operator &op) {
395 const Dialect &defDialect = op.getDialect();
396
397 // Generate for the selected dialect only
398 if (defDialect.getName() != DialectName) {
399 return;
400 }
401
402 this->setDialectAndClassName(&defDialect, op.getCppClassName());
403
404 if (GenIsA) {
405 this->genIsATest();
406 }
407 if (GenOpBuild && !op.skipDefaultBuilders()) {
408 this->genBuildOpTest(op);
409 }
411 this->genOperandTests(op);
412 }
414 this->genAttributeTests(op);
415 }
416 if (GenOpRegionGetters) {
417 this->genRegionTests(op);
418 }
419 if (GenOpResultGetters) {
420 this->genResultTests(op);
421 }
423 this->genExtraMethods(op.getExtraClassDeclaration());
424 }
425 }
426
427private:
431 static std::string generateBuildDummyParams(const Operator &op) {
432 struct : GenStringFromOpPieces {
433 void genHeader(llvm::raw_ostream &os) override {
434 // Declare dummyValue first
435 os << " auto dummyValue = mlirOperationGetResult(testOp, 0);\n";
436 }
437 void genResult(
438 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
439 ) override {
440 if (result.isVariadic()) {
441 os << llvm::formatv(
442 " auto {0}TypeArray = createIndexType();\n"
443 " MlirType {0}Types[] = {{{0}TypeArray};\n"
444 " intptr_t {0}Size = 0;\n",
445 resultName
446 );
447 } else {
448 os << llvm::formatv(" auto {0}Type = createIndexType();\n", resultName);
449 }
450 }
451 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
452 // per `generateParamList()` only need to create something additional in case
453 // of variadic operand, otherwise `dummyValue` is used directly.
454 if (operand.isVariadic()) {
455 os << llvm::formatv(
456 " MlirValue {0}Values[] = {{dummyValue};\n"
457 " intptr_t {0}Size = 0;\n",
458 operand.name
459 );
460 }
461 }
462 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
463 std::string rhs;
464 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
465 if (attrType.has_value() && attrType.value() == "MlirIdentifier") {
466 rhs = "mlirOperationGetName(testOp)";
467 } else {
468 rhs = "createIndexAttribute()";
469 }
470 os << llvm::formatv(" auto {0}Attr = {1};\n", attr.name, rhs);
471 }
472 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
473 if (region.isVariadic()) {
474 os << llvm::formatv(" unsigned {0}Count = 0;\n", region.name);
475 }
476 }
477 } paramsStringGenerator;
478 return paramsStringGenerator.gen(op);
479 }
480
484 static std::string generateBuildParamList(const Operator &op) {
485 struct : GenStringFromOpPieces {
486 void genResult(
487 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
488 ) override {
489 if (result.isVariadic()) {
490 os << llvm::formatv(", {0}Size, {0}Types", resultName);
491 } else {
492 os << llvm::formatv(", {0}Type", resultName);
493 }
494 }
495 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
496 if (operand.isVariadic()) {
497 os << llvm::formatv(", {0}Size, {0}Values", operand.name);
498 } else {
499 os << ", dummyValue";
500 }
501 }
502 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
503 os << llvm::formatv(", {0}Attr", attr.name);
504 }
505 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
506 if (region.isVariadic()) {
507 os << llvm::formatv(", {0}Count", region.name);
508 }
509 }
510 } paramsStringGenerator;
511 return paramsStringGenerator.gen(op);
512 }
513};
514
515} // namespace
516
518static bool emitOpCAPITests(const llvm::RecordKeeper &records, raw_ostream &os) {
519 // Generate file header
520 emitSourceFileHeader("Op C API Tests", os, records);
521
522 // Create generator
523 OpTestGenerator generator(os);
524
525 // Generate test class prologue
526 generator.genTestClassPrologue();
527
528 // Generate tests for each operation
529 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
530 Operator op(def);
531 generator.genCompleteRecord(op);
532 }
533
534 return false;
535}
536
537static mlir::GenRegistration
538 genOpCAPITests("gen-op-capi-tests", "Generate operation C API unit tests", &emitOpCAPITests);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for common test implementation file elements.