LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
CommonAttrOrTypeCAPIGen.h
Go to the documentation of this file.
1//===- CommonAttrOrTypeCAPIGen.h ------------------------------------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Common utilities shared between Attr and Type CAPI generators
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15
16#include <mlir/TableGen/AttrOrTypeDef.h>
17
18#include "CommonCAPIGen.h"
19
/// Inherit HeaderGenerator's constructors.
25 using HeaderGenerator::HeaderGenerator;
/// Defaulted; virtual because this class declares overridable (virtual) gen*Decl hooks.
26 virtual ~AttrOrTypeHeaderGenerator() = default;
27
/// Set the name of the parameter that the next parameter-getter declaration
/// (genParameterGetterDecl / genArrayRefParameterGetterDecls) describes; both of
/// those assert that it is non-empty.
30 void setParamName(mlir::StringRef name) {
// Stored as a non-owning StringRef, so the referenced characters must stay
// alive until the getter declaration has been emitted.
31 this->paramName = name;
33 }
34
/// Emit the C API getter declaration for a non-ArrayRef parameter of the current
/// attribute/type: an `MLIR_CAPI_EXPORTED` function taking `Mlir<kind>` and
/// returning the parameter's C API type, preceded by a one-line C comment that
/// names the parameter and the C++ class.
///
/// @param cppType C++ type of the parameter; its C API counterpart, obtained with
///        mapCppTypeToCapiType(), becomes the return type.
/// Preconditions (asserted): `dialect` is set and `paramName` is non-empty.
36 virtual void genParameterGetterDecl(mlir::StringRef cppType) const {
37 static constexpr char fmt[] = R"(
38/* Get '{5}' parameter from a {6}::{3} {1}. */
39MLIR_CAPI_EXPORTED {7} {0}{2}{3}Get{4}(Mlir{1});
40)";
41 assert(dialect && "Dialect must be set");
42 assert(!paramName.empty() && "paramName must be set");
// NOTE(review): placeholders {2} and {4} are presumably the capitalized dialect
// and parameter names (dialectNameCapitalized / paramNameCapitalized) — confirm.
43 os << llvm::formatv(
44 fmt,
45 FunctionPrefix, // {0}
46 kind, // {1}
48 className, // {3}
50 paramName, // {5}
51 dialect->getCppNamespace(), // {6}
52 mapCppTypeToCapiType(cppType) // {7}
53 );
54 }
55
/// Emit the two C API declarations that expose an ArrayRef parameter of the
/// current attribute/type element by element:
///   - a `...Count(Mlir<kind>)` getter returning the number of elements as intptr_t;
///   - a `...At(Mlir<kind>, intptr_t pos)` getter returning the element at `pos`.
///
/// @param cppType C++ `ArrayRef<...>` type of the parameter; only its element type,
///        mapped with mapCppTypeToCapiType(), appears in the C signatures.
/// Preconditions (asserted): `dialect` is set and `paramName` is non-empty.
57 virtual void genArrayRefParameterGetterDecls(mlir::StringRef cppType) const {
58 static constexpr char fmt[] = R"(
59/* Get count of '{5}' parameter from a {6}::{3} {1}. */
60MLIR_CAPI_EXPORTED intptr_t {0}{2}{3}Get{4}Count(Mlir{1});
61
62/* Get element at index from '{5}' parameter from a {6}::{3} {1}. */
63MLIR_CAPI_EXPORTED {7} {0}{2}{3}Get{4}At(Mlir{1}, intptr_t pos);
64)";
65 assert(dialect && "Dialect must be set");
66 assert(!paramName.empty() && "paramName must be set");
// C has no ArrayRef type, so only the element type is mapped to the C API.
67 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
68 os << llvm::formatv(
69 fmt,
70 FunctionPrefix, // {0}
71 kind, // {1}
73 className, // {3}
75 paramName, // {5}
76 dialect->getCppNamespace(), // {6}
77 mapCppTypeToCapiType(cppElemType) // {7}
78 );
79 }
80
/// Emit the declaration of the default builder
/// `Mlir<kind> ...Get(MlirContext ctx, ...)`, with C arguments for each
/// parameter of `def`, in declaration order:
///   - an ArrayRef parameter becomes `intptr_t <name>Count, <CElemType> *<name>`;
///   - every other parameter becomes `<CType> <name>`.
///
/// @param def The attribute/type definition whose parameters are listed.
/// Precondition (asserted): `dialect` is set.
82 virtual void genDefaultGetBuilderDecl(const mlir::tblgen::AttrOrTypeDef &def) const {
83 static constexpr char fmt[] = R"(
84/* Create a {5}::{3} {1} with the given parameters. */
85MLIR_CAPI_EXPORTED Mlir{1} {0}{2}{3}Get(MlirContext ctx{4});
86)";
87 assert(dialect && "Dialect must be set");
88
89 // Use raw_string_ostream for efficient string building of parameter list
// Every entry begins with ", " so the list can follow `MlirContext ctx` directly.
90 std::string paramListBuffer;
91 llvm::raw_string_ostream paramListStream(paramListBuffer);
92 for (const auto &param : def.getParameters()) {
93 mlir::StringRef cppType = param.getCppType();
94 if (isArrayRefType(cppType)) {
95 // For ArrayRef parameters, use intptr_t for count and pointer to element type
96 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
97 paramListStream << ", intptr_t " << param.getName() << "Count, "
98 << mapCppTypeToCapiType(cppElemType) << " *" << param.getName();
99 } else {
100 paramListStream << ", " << mapCppTypeToCapiType(cppType) << ' ' << param.getName();
101 }
102 }
104 os << llvm::formatv(
105 fmt,
106 FunctionPrefix, // {0}
107 kind, // {1}
109 className, // {3}
110 paramListBuffer, // {4}
111 dialect->getCppNamespace() // {5}
112 );
113 }
114
/// Emit all C header declarations for one attribute/type definition `def`.
/// Definitions that belong to a dialect other than the one named by DialectName
/// are skipped. Otherwise this emits, in order:
///   - the IsA check (if GenIsA);
///   - the default Get builder (if GenTypeOrAttrGet and `def` does not skip
///     default builders);
///   - getter declarations for each parameter;
///   - declarations derived from `def`'s extra class declarations, when present.
/// NOTE(review): the GenTypeOrAttrParamGetters / GenExtraClassMethods options
/// presumably gate the last two groups — confirm.
/// NOTE(review): `def` is taken by value, unlike genDefaultGetBuilderDecl, which
/// takes `const mlir::tblgen::AttrOrTypeDef &`; consider passing by reference.
115 void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef def) {
116 const mlir::tblgen::Dialect &defDialect = def.getDialect();
117
118 // Generate for the selected dialect only
119 if (defDialect.getName() != DialectName) {
120 return;
121 }
122
// Record the dialect and C++ class name used by the declaration generators below.
123 this->setDialectAndClassName(&defDialect, def.getCppClassName());
124
125 // Generate IsA check
126 if (GenIsA) {
127 this->genIsADecl();
128 }
129
130 // Generate default Get builder if not skipped
131 if (GenTypeOrAttrGet && !def.skipDefaultBuilders()) {
132 this->genDefaultGetBuilderDecl(def);
133 }
134
135 // Generate parameter getters
137 for (const auto &param : def.getParameters()) {
// The getter generators read the current parameter name from paramName.
138 this->setParamName(param.getName());
139 mlir::StringRef cppType = param.getCppType();
140 if (isArrayRefType(cppType)) {
141 this->genArrayRefParameterGetterDecls(cppType);
142 } else {
143 this->genParameterGetterDecl(cppType);
144 }
145 }
148 // Generate extra class method declarations
150 std::optional<mlir::StringRef> extraDecls = def.getExtraDecls();
151 if (extraDecls.has_value()) {
152 this->genExtraMethods(extraDecls.value());
153 }
154 }
155 }
156
157protected:
/// Name of the parameter currently being generated; set by setParamName().
/// Non-owning view.
158 mlir::StringRef paramName;
/// Presumably the PascalCase form of paramName used in generated function
/// names — confirm.
159 std::string paramNameCapitalized;
160};
161
/// Inherit ImplementationGenerator's constructors.
167 using ImplementationGenerator::ImplementationGenerator;
169
/// Set the name of the parameter whose getter definition(s) are generated next
/// (genParameterGetterImpl / genArrayRefParameterImpls assert it is non-empty).
172 void setParamName(mlir::StringRef name) {
// Stored as a non-owning StringRef, so the referenced characters must stay
// alive until the getter definition has been emitted.
173 this->paramName = name;
175 }
176
/// Emit the preamble of the generated implementation file: includes for the
/// MLIR C API IR/Support headers and llvm/ADT/TypeSwitch.h, followed by
/// `using namespace mlir;` and `using namespace llvm;`.
177 virtual void genPrologue() const {
178 os << R"(
179#include <mlir/CAPI/IR.h>
180#include <mlir/CAPI/Support.h>
181#include <llvm/ADT/TypeSwitch.h>
182
183using namespace mlir;
184using namespace llvm;
185)";
186 }
187
/// Emit the definitions of the two ArrayRef accessors for the current parameter:
///   - the `...Count(Mlir<kind> inp)` function returns the ArrayRef's size() as
///     intptr_t;
///   - the `...At(Mlir<kind> inp, intptr_t pos)` function returns element `pos`,
///     passed through `wrap` unless the element type is primitive.
/// Both llvm::cast `unwrap(inp)` to the C++ class. No bounds check on `pos` is
/// emitted, so callers of the generated API must pass 0 <= pos < Count.
///
/// @param cppType C++ `ArrayRef<...>` type of the parameter.
/// Preconditions (asserted): `className` is set and `paramName` is non-empty.
188 virtual void genArrayRefParameterImpls(mlir::StringRef cppType) const {
// `{{` is formatv's escape for a literal '{' in the emitted code.
189 static constexpr char fmt[] = R"(
190intptr_t {0}{2}{3}Get{4}Count(Mlir{1} inp) {{
191 return static_cast<intptr_t>(llvm::cast<{3}>(unwrap(inp)).get{4}().size());
192}
193
194{5} {0}{2}{3}Get{4}At(Mlir{1} inp, intptr_t pos) {{
195 return {6}(llvm::cast<{3}>(unwrap(inp)).get{4}()[pos]);
196}
197 )";
198 assert(!className.empty() && "className must be set");
199 assert(!paramName.empty() && "paramName must be set");
200 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
201 os << llvm::formatv(
202 fmt,
203 FunctionPrefix, // {0}
204 kind, // {1}
206 className, // {3}
208 mapCppTypeToCapiType(cppElemType), // {5}
// Primitive elements need no C API wrapping, so {6} then expands to nothing.
209 isPrimitiveType(cppElemType) ? "" : "wrap" // {6}
210 );
211 }
212
/// Emit the definition of the getter for a non-ArrayRef parameter. The
/// generated function llvm::casts `unwrap(inp)` to the C++ class, calls the
/// parameter's `get...()` accessor, and returns the result, passed through
/// `wrap` unless the type is primitive.
///
/// @param cppType C++ type of the parameter.
/// Preconditions (asserted): `className` is set and `paramName` is non-empty.
213 virtual void genParameterGetterImpl(mlir::StringRef cppType) const {
// `{{` is formatv's escape for a literal '{' in the emitted code.
214 static constexpr char fmt[] = R"(
215{5} {0}{2}{3}Get{4}(Mlir{1} inp) {{
216 return {6}(llvm::cast<{3}>(unwrap(inp)).get{4}());
217}
218 )";
219 assert(!className.empty() && "className must be set");
220 assert(!paramName.empty() && "paramName must be set");
221 os << llvm::formatv(
222 fmt,
223 FunctionPrefix, // {0}
224 kind, // {1}
226 className, // {3}
228 mapCppTypeToCapiType(cppType), // {5}
// Primitive values need no C API wrapping, so {6} then expands to nothing.
229 isPrimitiveType(cppType) ? "" : "wrap" // {6}
230 );
231 }
232
/// Emit the definition of the default builder
/// `Mlir<kind> ...Get(MlirContext ctx, ...)`, which returns
/// `wrap(<Class>::get(unwrap(ctx), args...))`. The C parameter list is spelled
/// the same way as in genDefaultGetBuilderDecl; each argument is converted back
/// to its C++ form before being forwarded:
///   - primitive values are passed through unchanged;
///   - MlirAttribute / MlirType values are unwrapped and llvm::cast to the
///     parameter's C++ type;
///   - any other C API value is passed through `unwrap`;
///   - ArrayRef parameters are rebuilt from (pointer, count), applying
///     `unwrapList` when the elements are not primitive.
///
/// @param def The attribute/type definition whose parameters are forwarded.
/// Precondition (asserted): `className` is set.
234 virtual void genDefaultGetBuilderImpl(const mlir::tblgen::AttrOrTypeDef &def) const {
235 static constexpr char fmt[] = R"(
236Mlir{1} {0}{2}{3}Get(MlirContext ctx{4}) {{
237 return wrap({3}::get(unwrap(ctx){5}));
238}
239 )";
240 assert(!className.empty() && "className must be set");
241
242 // Use raw_string_ostream for efficient string building
// Both lists start each entry with ", " so they can follow `ctx` / `unwrap(ctx)`.
243 std::string paramListBuffer;
244 std::string argListBuffer;
245 llvm::raw_string_ostream paramListStream(paramListBuffer);
246 llvm::raw_string_ostream argListStream(argListBuffer);
247
248 for (const auto &param : def.getParameters()) {
249 mlir::StringRef pName = param.getName();
250 mlir::StringRef cppType = param.getCppType();
251 if (isArrayRefType(cppType)) {
252 // For ArrayRef parameters, convert from pointer + count to ArrayRef
253 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
254 std::string capiElemType = mapCppTypeToCapiType(cppElemType);
255 paramListStream << ", intptr_t " << pName << "Count, " << capiElemType << " *" << pName;
256
257 // In the call, we need to convert back to ArrayRef. Check if elements need unwrapping.
258 if (isPrimitiveType(cppElemType)) {
259 argListStream << ", ::llvm::ArrayRef<" << capiElemType << ">(" << pName << ", " << pName
260 << "Count)";
261 } else {
// NOTE(review): this emits `::llvm::ArrayRef<CApiElemType>(unwrapList(ptr, count))`.
// The ArrayRef is parameterized by the C API element type rather than the
// unwrapped C++ element type, and upstream MLIR's unwrapList (mlir/CAPI/Wrap.h)
// takes (size, first, storage). Confirm that a (ptr, count) overload whose result
// converts to this ArrayRef is in scope; otherwise the generated code will not compile.
262 argListStream << ", ::llvm::ArrayRef<" << capiElemType << ">(unwrapList(" << pName << ", "
263 << pName << "Count))";
264 }
265 } else {
266 std::string capiType = mapCppTypeToCapiType(cppType);
267 paramListStream << ", " << capiType << ' ' << pName;
268
269 // Add unwrapping if needed
270 argListStream << ", ";
271 if (isPrimitiveType(cppType)) {
272 argListStream << pName;
273 } else if (capiType == "MlirAttribute" || capiType == "MlirType") {
274 // Needs additional cast to the specific attribute/type class
275 argListStream << "::llvm::cast<" << cppType << ">(unwrap(" << pName << "))";
276 } else {
277 // Any other cases, just use an "unwrap" function
278 argListStream << "unwrap(" << pName << ")";
279 }
280 }
281 }
282
283 os << llvm::formatv(
284 fmt,
285 FunctionPrefix, // {0}
286 kind, // {1}
288 className, // {3}
289 paramListBuffer, // {4}
290 argListBuffer // {5}
291 );
292 }
293
/// Emit all C API definitions for one attribute/type definition `def`.
/// Definitions from a dialect other than the one named by DialectName are
/// skipped. Otherwise this emits, in order:
///   - the IsA implementation (if GenIsA);
///   - the default Get builder (if GenTypeOrAttrGet and `def` does not skip
///     default builders);
///   - getter definitions for each parameter;
///   - code derived from `def`'s extra class declarations, when present.
/// NOTE(review): the GenTypeOrAttrParamGetters / GenExtraClassMethods options
/// presumably gate the last two groups — confirm.
/// NOTE(review): `def` is taken by value, unlike genDefaultGetBuilderImpl, which
/// takes `const mlir::tblgen::AttrOrTypeDef &`; consider passing by reference.
294 void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef def) {
295 const mlir::tblgen::Dialect &defDialect = def.getDialect();
296
297 // Generate for the selected dialect only
298 if (defDialect.getName() != DialectName) {
299 return;
300 }
301
// Record the dialect and C++ class name used by the definition generators below.
302 this->setDialectAndClassName(&defDialect, def.getCppClassName());
303
304 // Generate IsA check implementation
305 if (GenIsA) {
306 this->genIsAImpl();
307 }
308
309 // Generate default Get builder implementation if not skipped
310 if (GenTypeOrAttrGet && !def.skipDefaultBuilders()) {
311 this->genDefaultGetBuilderImpl(def);
312 }
313
314 // Generate parameter getter implementations
316 for (const auto &param : def.getParameters()) {
// The getter generators read the current parameter name from paramName.
317 this->setParamName(param.getName());
318 mlir::StringRef cppType = param.getCppType();
319 if (isArrayRefType(cppType)) {
320 this->genArrayRefParameterImpls(cppType);
321 } else {
322 this->genParameterGetterImpl(cppType);
323 }
324 }
325 }
326
327 // Generate extra class method implementations
329 std::optional<mlir::StringRef> extraDecls = def.getExtraDecls();
330 if (extraDecls.has_value()) {
331 this->genExtraMethods(extraDecls.value());
332 }
333 }
334 }
335
336protected:
/// Name of the parameter currently being generated; set by setParamName().
/// Non-owning view.
337 mlir::StringRef paramName;
/// Presumably the PascalCase form of paramName used in generated function
/// names — confirm.
338 std::string paramNameCapitalized;
339};
std::string mapCppTypeToCapiType(StringRef cppType)
mlir::StringRef extractArrayRefElementType(mlir::StringRef cppType)
Extract the element type from an ArrayRef<...> C++ type name.
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for attribute/type C header files.
void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef def)
virtual void genArrayRefParameterGetterDecls(mlir::StringRef cppType) const
Generate the count and element-at-index accessor declarations for an ArrayRef parameter.
void setParamName(mlir::StringRef name)
Set the parameter name for code generation.
virtual ~AttrOrTypeHeaderGenerator()=default
virtual void genDefaultGetBuilderDecl(const mlir::tblgen::AttrOrTypeDef &def) const
Generate default Get builder declaration.
virtual void genParameterGetterDecl(mlir::StringRef cppType) const
Generate regular getter for non-ArrayRef type parameter.
Generator for attribute/type C implementation files.
void setParamName(mlir::StringRef name)
Set the parameter name for code generation.
virtual void genDefaultGetBuilderImpl(const mlir::tblgen::AttrOrTypeDef &def) const
Generate default Get builder implementation.
virtual ~AttrOrTypeImplementationGenerator()=default
virtual void genArrayRefParameterImpls(mlir::StringRef cppType) const
void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef def)
virtual void genParameterGetterImpl(mlir::StringRef cppType) const
virtual void setDialectAndClassName(const mlir::tblgen::Dialect *d, mlir::StringRef cppClassName)
Set the dialect and class name for code generation.
mlir::StringRef className
virtual void genExtraMethods(mlir::StringRef extraDecl) const
Generate code for the extra methods declared in an extraClassDeclaration.
const mlir::tblgen::Dialect * dialect
std::string dialectNameCapitalized
llvm::raw_ostream & os
std::string kind
Generator for common C header file elements.
virtual void genIsADecl() const
Generator for common C implementation file elements.
virtual void genIsAImpl() const