LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
EmptyParamListRemovalPass.cpp
Go to the documentation of this file.
1//===-- EmptyParamListRemovalPass.cpp ---------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
19#include <mlir/Dialect/SCF/Transforms/Patterns.h>
20#include <mlir/Transforms/DialectConversion.h>
22// Include the generated base pass class definitions.
24#define GEN_PASS_DEF_EMPTYPARAMLISTREMOVALPASS
26} // namespace llzk::polymorphic
27
28#include "SharedImpl.h"
30#define DEBUG_TYPE "llzk-drop-empty-params"
31
32using namespace mlir;
33using namespace llzk::array;
34using namespace llzk::component;
35using namespace llzk::polymorphic::detail;
36
37namespace {
38
39bool hasEmptyParamList(StructType t) {
40 if (ArrayAttr paramList = t.getParams()) {
41 if (paramList.empty()) {
42 return true;
43 }
44 }
45 return false;
46}
47
48/// Convert StructType with empty parameter list to one with no parameters.
49class EmptyParamListStructTypeConverter : public TypeConverter {
50public:
51 EmptyParamListStructTypeConverter() : TypeConverter() {
52
53 addConversion([](Type inputTy) { return inputTy; });
54
55 addConversion([](StructType inputTy) -> StructType {
56 return hasEmptyParamList(inputTy) ? StructType::get(inputTy.getNameRef()) : inputTy;
57 });
58
59 addConversion([this](ArrayType inputTy) {
60 // Recursively convert element type
61 return ArrayType::get(
62 this->convertType(inputTy.getElementType()), inputTy.getDimensionSizes()
63 );
64 });
65 }
66};
67
68class CallOpTypeReplacePattern : public OpConversionPattern<StructDefOp> {
69public:
70 using OpConversionPattern<StructDefOp>::OpConversionPattern;
71
72 LogicalResult match(StructDefOp op) const override {
73 return success(hasEmptyParamList(op.getType()));
74 }
75
76 void
77 rewrite(StructDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
78 rewriter.modifyOpInPlace(op, [&op]() { op.setConstParamsAttr(nullptr); });
79 }
80};
81
84class EmptyParamRemovalPass
85 : public llzk::polymorphic::impl::EmptyParamListRemovalPassBase<EmptyParamRemovalPass> {
86
87 void runOnOperation() override {
88 ModuleOp modOp = getOperation();
89 MLIRContext *ctx = modOp.getContext();
90 EmptyParamListStructTypeConverter tyConv;
91 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
92 // Mark StructDefOp with empty parameter list as illegal
93 target.addDynamicallyLegalOp<StructDefOp>([](StructDefOp op) {
94 return !hasEmptyParamList(op.getType());
95 });
96 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
97 patterns.add<CallOpTypeReplacePattern>(tyConv, ctx);
98 if (failed(applyFullConversion(modOp, target, std::move(patterns)))) {
99 signalPassFailure();
100 }
101 }
102};
103
104} // namespace
105
107 return std::make_unique<EmptyParamRemovalPass>();
108};
Common private implementation for poly dialect passes.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
Definition Ops.cpp:142
void setConstParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.cpp.inc:1975
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:69
::mlir::ArrayAttr getParams() const
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:239
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:268
std::unique_ptr< mlir::Pass > createEmptyParamListRemoval()