25#include <mlir/Dialect/Arith/IR/Arith.h>
26#include <mlir/Dialect/SCF/IR/SCF.h>
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/ADT/SmallVector.h>
71 WithGeneralBuilder {};
// Second group of Op classes, kept apart from WithGeneralBuilder. CallOp and CreateArrayOp each
// have a dedicated conversion pattern elsewhere in this file.
// NOTE(review): presumably they are excluded from the general pattern because they need custom
// builders; the opening lines of this struct are not visible in this excerpt, so confirm.
77 const std::tuple<llzk::function::CallOp, llzk::array::CreateArrayOp> NoGeneralBuilder {};
79} OpClassesWithStructTypes;
/// Invokes `inserter` via std::apply with a value-initialized tuple of the given Op classes, so the
/// callable receives one argument per class (NextOpClass followed by OtherOpClasses...).
/// The inserters defined in this file only use the argument *types* (through `decltype`).
/// The overload taking no Op classes handles the empty case.
81template <
typename I,
typename NextOpClass,
typename... OtherOpClasses>
82inline void applyToMoreTypes(I inserter) {
83 std::apply(inserter, std::tuple<NextOpClass, OtherOpClasses...> {});
/// Overload of applyToMoreTypes selected when no Op classes are supplied (for example an empty
/// AdditionalOpClasses pack at the call sites in this file): there is nothing to apply, so the
/// body is empty and `inserter` is unused.
85template <
typename I>
inline void applyToMoreTypes(I inserter) {}
/// Legality predicate for `op` with respect to the type converter `tyConv`.
/// Visible checks: the op itself must satisfy `tyConv.isLegal(op)`, and the type wrapped by every
/// TypeAttr in the op's attribute dictionary is checked as well.
/// NOTE(review): the return statements are not visible in this excerpt; presumably any failed
/// check returns `false` and the function returns `true` otherwise. Confirm against the full file.
87inline bool defaultLegalityCheck(
const mlir::TypeConverter &tyConv, mlir::Operation *op) {
// Op-level check performed by the converter.
89 if (!tyConv.isLegal(op)) {
// Also examine types that are stored inside attributes (TypeAttr values).
94 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
95 for (mlir::NamedAttribute n : dictAttr.getValue()) {
96 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
97 mlir::Type t = tyAttr.getValue();
// Function types are validated through the converter's signature-level check.
98 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
99 if (!tyConv.isSignatureLegal(funcTy)) {
// NOTE(review): presumably the branch for non-function types; the lines in between are not
// visible in this excerpt.
103 if (!tyConv.isLegal(t)) {
/// Runs `check` on `op` when `op` can be cast to the Op class that `check` declares as its first
/// parameter. That parameter type is obtained from `llvm::function_traits<Check>`.
/// @param op    operation to test; cast with `dyn_cast_if_present`.
/// @param check callable taking a specific Op class; its result is returned when the cast succeeds.
/// NOTE(review): the value returned when the cast fails is not visible in this excerpt.
/// Presumably it is `true`, so that an inapplicable check does not make the op illegal. Confirm.
113template <
typename Check>
inline bool runCheck(mlir::Operation *op, Check check) {
114 if (
auto specificOp =
115 llvm::dyn_cast_if_present<
typename llvm::function_traits<Check>::template arg_t<0>>(op)) {
116 return check(specificOp);
/// Fragment of the replaceOpWithNewOp wrapper. The signature listed at the end of this file is
/// `OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)`.
/// The visible statements carry the discardable attributes of the replaced op over to the new op.
125template <
typename OpClass,
typename Rewriter,
typename... Args>
// Read the discardable attributes of the op that is about to be replaced.
127 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
// Attach them to the replacement op.
// NOTE(review): the statement that creates `newOp` is not visible in this excerpt.
129 newOp->setDiscardableAttrs(attrs);
/// Conversion pattern, generic over `OpClass`, that rebuilds an op with converted result types,
/// the adaptor's operands, and a copy of the op's attributes in which every TypeAttr holds the
/// converted type.
/// NOTE(review): the class header is not visible in this excerpt; presumably this is the
/// GeneralTypeReplacePattern mentioned in the summary at the end of this file. Confirm.
136template <
typename OpClass>
// Reuse the constructors of the OpConversionPattern base.
139 using mlir::OpConversionPattern<OpClass>::OpConversionPattern;
142 OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
144 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
// Convert all result types up front; report an error on the op if that is not possible.
147 mlir::SmallVector<mlir::Type> newResultTypes;
148 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
149 return op->emitError(
"Could not convert Op result types.");
// Condition fragment: the adaptor has no attributes, or the name of every adaptor attribute is
// present in the op's own attribute dictionary.
// NOTE(review): the enclosing statement is not visible; it looks like an assertion. Confirm.
154 adaptor.getAttributes().empty() ||
156 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a
157 ) { return d.contains(a.getName()); }
// Copy the op's attributes and replace the type held by each TypeAttr with its converted type.
161 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
162 for (mlir::NamedAttribute &n : newAttrs) {
163 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
164 if (mlir::Type newType = converter->convertType(t.getValue())) {
165 n.setValue(mlir::TypeAttr::get(newType));
// A null conversion result is reported as an error on the op.
167 return op->emitError().append(
"Could not convert type in attribute: ", t);
// Arguments of the replacement call: converted result types, the adaptor's operands, and the
// updated attribute list.
// NOTE(review): the callee line is not visible; presumably replaceOpWithNewOp<OpClass>.
173 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
174 mlir::ArrayRef(newAttrs)
176 return mlir::success();
// Conversion pattern specialised for llzk::array::CreateArrayOp, split into `match` and `rewrite`
// (their signatures are listed at the end of this file).
// NOTE(review): the class name line is not visible in this excerpt.
181 :
public mlir::OpConversionPattern<llzk::array::CreateArrayOp> {
// match: succeeds when the op's result type can be converted, otherwise reports an error.
186 if (mlir::Type newType = getTypeConverter()->convertType(op.getType())) {
187 return mlir::success();
189 return op->emitError(
"Could not convert Op result type.");
// rewrite: converts the result type again and requires the converted type to be an ArrayType.
196 mlir::Type newType = getTypeConverter()->convertType(op.getType());
198 llvm::isa<llzk::array::ArrayType>(newType) &&
"CreateArrayOp must produce ArrayType result"
// Replacement form built from the adaptor's element operands.
// NOTE(review): the condition choosing between this form and the next is not visible here.
203 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.
getElements()
// Replacement form built from the adaptor's map operands (remaining arguments are not visible).
207 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.
getMapOperands(),
// Body fragment of the CallOp pattern's matchAndRewrite (signature listed at the end of this
// file); the enclosing class header is not visible in this excerpt.
// Convert every result type of the call; report an error on the op if conversion fails.
222 mlir::SmallVector<mlir::Type> newResultTypes;
223 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
224 return op->emitError(
"Could not convert Op result types.");
// Arguments of the replacement call: converted result types, the original callee attribute, and
// the adaptor's map operands. NOTE(review): the callee line and trailing arguments are not visible.
227 rewriter, op, newResultTypes, op.
getCalleeAttr(), adapter.getMapOperands(),
230 return mlir::success();
/// Fragment of newGeneralRewritePatternSet (signature listed at the end of this file): builds a
/// RewritePatternSet for `ctx`.
/// @tparam AdditionalOpClasses extra Op classes passed to the inserter besides those in
///         OpClassesWithStructTypes.WithGeneralBuilder.
/// NOTE(review): the inserter's body and the return statement are not visible in this excerpt.
238template <
typename... AdditionalOpClasses>
240 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
242 mlir::RewritePatternSet patterns(ctx);
// Generic lambda invoked with one argument per Op class.
243 auto inserter = [&](
auto... opClasses) {
// Apply the inserter to the WithGeneralBuilder Op classes, then to any additional classes.
246 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
247 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
// Function-signature conversion pattern for FuncDefOp (arguments are not visible here).
251 mlir::populateFunctionOpInterfaceTypeConversionPattern<llzk::function::FuncDefOp>(
// SCF structural type conversions; this call also receives `target` and so can update legality.
254 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
/// Fragment of newConverterDefinedTarget (signature listed at the end of this file; it returns an
/// mlir::ConversionTarget). For each Op class it registers a dynamic legality callback that
/// requires defaultLegalityCheck(tyConv, op) and every additional check (through runCheck) to pass.
/// @tparam AdditionalOpClasses extra Op classes to register besides OpClassesWithStructTypes.
/// @tparam AdditionalChecks    callables evaluated through runCheck in the legality callback.
/// NOTE(review): the creation of `target` and the return statement are not visible in this excerpt.
/// NOTE(review): the legality callback captures `tyConv` and `checks` by reference and is stored
/// in `target`. If the target outlives the caller's arguments (for example checks passed as
/// temporaries), those references would dangle. Confirm caller lifetimes or capture by value.
267template <
typename... AdditionalOpClasses,
typename... AdditionalChecks>
269 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks
// Generic lambda: only the types of its arguments are used, to name the Op classes.
272 auto inserter = [&](
auto... opClasses) {
273 target.addDynamicallyLegalOp<
decltype(opClasses)...>([&tyConv,
274 &checks...](mlir::Operation *op) {
275 return defaultLegalityCheck(tyConv, op) && (runCheck<AdditionalChecks>(op, checks) && ...);
// Register both Op-class groups and then any additional classes.
278 std::apply(inserter, OpClassesWithStructTypes.NoGeneralBuilder);
279 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
280 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::OperandRangeRange getMapOperands()
::mlir::Operation::operand_range getElements()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::SymbolRefAttr getCalleeAttr()
mlir::LogicalResult matchAndRewrite(llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override
void rewrite(llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
mlir::LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
bool isNullOrEmpty(mlir::ArrayAttr a)