25#include <mlir/Dialect/Arith/IR/Arith.h>
26#include <mlir/Dialect/SCF/IR/SCF.h>
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/ADT/SmallVector.h>
45static struct OpClassesWithStructTypes {
71 WithGeneralBuilder {};
77 const std::tuple<llzk::function::CallOp, llzk::array::CreateArrayOp> NoGeneralBuilder {};
79} OpClassesWithStructTypes;
/// Invoke `inserter` exactly once with a default-constructed instance of every
/// type in `OpClasses`, passed together as one variadic call via `std::apply`.
/// When `OpClasses` is empty, `inserter` is not invoked at all.
///
/// Typical use passes the op classes explicitly:
///   `applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);`
/// `I` may also be deduced when no op classes are supplied.
///
/// @tparam I         Callable accepting `OpClasses...` by value.
/// @tparam OpClasses Types to default-construct and forward to `inserter`.
/// @param inserter   The callable to invoke.
template <typename I, typename... OpClasses>
inline void applyToMoreTypes([[maybe_unused]] I inserter) {
  // A single template with `if constexpr` replaces the former pair of
  // overloads, one of which existed only to absorb the empty-pack case and
  // left `inserter` unused. For an empty pack the callable is deliberately
  // not invoked (same as before), so callables that cannot accept zero
  // arguments are never instantiated with an empty argument list.
  if constexpr (sizeof...(OpClasses) > 0) {
    std::apply(inserter, std::tuple<OpClasses...> {});
  }
}
87inline bool defaultLegalityCheck(
const mlir::TypeConverter &tyConv, mlir::Operation *op) {
89 if (!tyConv.isLegal(op)) {
94 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
95 for (mlir::NamedAttribute n : dictAttr.getValue()) {
96 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
97 mlir::Type t = tyAttr.getValue();
98 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
99 if (!tyConv.isSignatureLegal(funcTy)) {
103 if (!tyConv.isLegal(t)) {
113template <
typename Check>
inline bool runCheck(mlir::Operation *op, Check check) {
114 if (
auto specificOp =
115 llvm::dyn_cast_if_present<
typename llvm::function_traits<Check>::template arg_t<0>>(op)) {
116 return check(specificOp);
125template <
typename OpClass,
typename Rewriter,
typename... Args>
127 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
129 newOp->setDiscardableAttrs(attrs);
136template <
typename OpClass>
139 using mlir::OpConversionPattern<OpClass>::OpConversionPattern;
142 OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
144 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
147 mlir::SmallVector<mlir::Type> newResultTypes;
148 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
149 return op->emitError(
"Could not convert Op result types.");
154 adaptor.getAttributes().empty() ||
156 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
157 return d.contains(a.getName());
162 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
163 for (mlir::NamedAttribute &n : newAttrs) {
164 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
165 if (mlir::Type newType = converter->convertType(t.getValue())) {
166 n.setValue(mlir::TypeAttr::get(newType));
168 return op->emitError().append(
"Could not convert type in attribute: ", t);
174 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
175 mlir::ArrayRef(newAttrs)
177 return mlir::success();
182 :
public mlir::OpConversionPattern<llzk::array::CreateArrayOp> {
187 if (mlir::Type newType = getTypeConverter()->convertType(op.getType())) {
188 return mlir::success();
190 return op->emitError(
"Could not convert Op result type.");
197 mlir::Type newType = getTypeConverter()->convertType(op.getType());
199 llvm::isa<llzk::array::ArrayType>(newType) &&
"CreateArrayOp must produce ArrayType result"
204 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.
getElements()
208 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.
getMapOperands(),
223 mlir::SmallVector<mlir::Type> newResultTypes;
224 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
225 return op->emitError(
"Could not convert Op result types.");
228 rewriter, op, newResultTypes, op.
getCalleeAttr(), adapter.getMapOperands(),
231 return mlir::success();
239template <
typename... AdditionalOpClasses>
241 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
243 mlir::RewritePatternSet patterns(ctx);
244 auto inserter = [&](
auto... opClasses) {
247 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
248 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
252 mlir::populateFunctionOpInterfaceTypeConversionPattern<llzk::function::FuncDefOp>(
255 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
268template <
typename... AdditionalOpClasses,
typename... AdditionalChecks>
270 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks
273 auto inserter = [&](
auto... opClasses) {
274 target.addDynamicallyLegalOp<
decltype(opClasses)...>([&tyConv,
275 &checks...](mlir::Operation *op) {
276 return defaultLegalityCheck(tyConv, op) && (runCheck<AdditionalChecks>(op, checks) && ...);
279 std::apply(inserter, OpClassesWithStructTypes.NoGeneralBuilder);
280 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
281 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::OperandRangeRange getMapOperands()
::mlir::Operation::operand_range getElements()
::mlir::SymbolRefAttr getCalleeAttr()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
mlir::LogicalResult matchAndRewrite(llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override
void rewrite(llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
mlir::LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStructTypes.WithGeneralBuilder and for the given AdditionalOpClasses.
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
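Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality, for the Ops in OpClassesWithStructTypes and AdditionalOpClasses, based on the given TypeConverter plus any additional checks.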
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i.e., those not inherent to the Op) from the replaced Op to the newly created Op.
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
bool isNullOrEmpty(mlir::ArrayAttr a)