LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SharedImpl.h
Go to the documentation of this file.
1//===-- SharedImpl.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
15#pragma once
16
24
25#include <mlir/Dialect/Arith/IR/Arith.h>
26#include <mlir/Dialect/SCF/IR/SCF.h>
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
34
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/ADT/SmallVector.h>
37
38#include <tuple>
39
41
42namespace {
43
// Type lists of the op classes handled by the conversion helpers in this file. Each list is a
// tuple of default-constructed op classes so that it can be expanded with `std::apply`.
45static struct OpClassesWithStructTypes {
46
// Op classes for which `newGeneralRewritePatternSet()` instantiates a `GeneralTypeReplacePattern`
// and which `newConverterDefinedTarget()` marks as dynamically legal.
// NOTE(review): the element list (listing lines 51-68) is not visible in this extract.
49 const std::tuple<
50 // clang-format off
69 // clang-format on
70 >
71 WithGeneralBuilder {};
72
// Op classes that `newConverterDefinedTarget()` also marks as dynamically legal but that are not
// expanded into `GeneralTypeReplacePattern` instances by `newGeneralRewritePatternSet()`.
77 const std::tuple<llzk::function::CallOp, llzk::array::CreateArrayOp> NoGeneralBuilder {};
78
79} OpClassesWithStructTypes;
80
/// Base case: there are no additional op classes, so the inserter is never invoked.
template <typename Inserter> inline void applyToMoreTypes(Inserter /*inserter*/) {}

/// Invoke `inserter` once, passing one default-constructed instance of each explicitly listed
/// op class. The instances only carry type information for the inserter's `auto...` parameters.
template <typename Inserter, typename FirstOpClass, typename... RestOpClasses>
inline void applyToMoreTypes(Inserter inserter) {
  using ExtraOpClasses = std::tuple<FirstOpClass, RestOpClasses...>;
  std::apply(inserter, ExtraOpClasses {});
}
86
87inline bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op) {
88 // Check operand types and result types
89 if (!tyConv.isLegal(op)) {
90 return false;
91 }
92 // Check type attributes
93 // Extend lifetime of temporary to suppress warnings.
94 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
95 for (mlir::NamedAttribute n : dictAttr.getValue()) {
96 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
97 mlir::Type t = tyAttr.getValue();
98 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
99 if (!tyConv.isSignatureLegal(funcTy)) {
100 return false;
101 }
102 } else {
103 if (!tyConv.isLegal(t)) {
104 return false;
105 }
106 }
107 }
108 }
109 return true;
110}
111
112// Default to true if the check is not for that particular operation type.
113template <typename Check> inline bool runCheck(mlir::Operation *op, Check check) {
114 if (auto specificOp =
115 llvm::dyn_cast_if_present<typename llvm::function_traits<Check>::template arg_t<0>>(op)) {
116 return check(specificOp);
117 }
118 return true;
119}
120
121} // namespace
122
125template <typename OpClass, typename Rewriter, typename... Args>
126inline OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args) {
127 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
128 OpClass newOp = rewriter.template replaceOpWithNewOp<OpClass>(op, std::forward<Args>(args)...);
129 newOp->setDiscardableAttrs(attrs);
130 return newOp;
131}
132
133// NOTE: This pattern will produce a compile error if `OpClass` does not define the general
134// `build(OpBuilder&, OperationState&, TypeRange, ValueRange, ArrayRef<NamedAttribute>)` function
135// because that function is required by the `replaceOpWithNewOp()` call.
136template <typename OpClass>
137class GeneralTypeReplacePattern : public mlir::OpConversionPattern<OpClass> {
138public:
139 using mlir::OpConversionPattern<OpClass>::OpConversionPattern;
140
141 mlir::LogicalResult matchAndRewrite(
142 OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
143 ) const override {
144 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
145 assert(converter);
146 // Convert result types
147 mlir::SmallVector<mlir::Type> newResultTypes;
148 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
149 return op->emitError("Could not convert Op result types.");
150 }
151 // ASSERT: 'adaptor.getAttributes()' is empty or subset of 'op->getAttrDictionary()' so the
152 // former can be ignored without losing anything.
153 assert(
154 adaptor.getAttributes().empty() ||
155 llvm::all_of(
156 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
157 return d.contains(a.getName());
158 }
159 )
160 );
161 // Convert any TypeAttr in the attribute list.
162 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
163 for (mlir::NamedAttribute &n : newAttrs) {
164 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
165 if (mlir::Type newType = converter->convertType(t.getValue())) {
166 n.setValue(mlir::TypeAttr::get(newType));
167 } else {
168 return op->emitError().append("Could not convert type in attribute: ", t);
169 }
170 }
171 }
172 // Build a new Op in place of the current one
174 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
175 mlir::ArrayRef(newAttrs)
176 );
177 return mlir::success();
178 }
179};
180
// Conversion pattern for llzk::array::CreateArrayOp that uses the separate match()/rewrite()
// hooks of OpConversionPattern.
// NOTE(review): the line carrying this class's name (listing line 181) is missing from this
// extract.
182 : public mlir::OpConversionPattern<llzk::array::CreateArrayOp> {
183public:
184 using mlir::OpConversionPattern<llzk::array::CreateArrayOp>::OpConversionPattern;
185
// Succeeds iff the type converter yields a non-null type for the op's result type; otherwise
// emits an error on the op. The converted type itself is not used here.
186 mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override {
187 if (mlir::Type newType = getTypeConverter()->convertType(op.getType())) {
188 return mlir::success();
189 } else {
190 return op->emitError("Could not convert Op result type.");
191 }
192 }
193
// NOTE(review): listing line 194 is missing here; per the symbol index at the end of this page
// it opens `void rewrite(` with the parameter list below.
195 llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
196 ) const override {
// Converts the result type again and asserts that it is an ArrayType.
197 mlir::Type newType = getTypeConverter()->convertType(op.getType());
198 assert(
199 llvm::isa<llzk::array::ArrayType>(newType) && "CreateArrayOp must produce ArrayType result"
200 );
// Chooses the argument list for the new op: the element operands when the numDimsPerMap
// attribute is null or empty, otherwise the map operands together with that attribute.
201 mlir::DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr();
202 if (isNullOrEmpty(numDimsPerMap)) {
// NOTE(review): listing line 203 (the start of the call taking these arguments) is missing.
204 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.getElements()
205 );
206 } else {
// NOTE(review): listing line 207 (the start of the call taking these arguments) is missing.
208 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.getMapOperands(),
209 numDimsPerMap
210 );
211 }
212 }
213};
214
215class CallOpClassReplacePattern : public mlir::OpConversionPattern<llzk::function::CallOp> {
216public:
217 using mlir::OpConversionPattern<llzk::function::CallOp>::OpConversionPattern;
218
219 mlir::LogicalResult matchAndRewrite(
220 llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
221 ) const override {
222 // Convert the result types of the CallOp
223 mlir::SmallVector<mlir::Type> newResultTypes;
224 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
225 return op->emitError("Could not convert Op result types.");
226 }
228 rewriter, op, newResultTypes, op.getCalleeAttr(), adapter.getMapOperands(),
229 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
230 );
231 return mlir::success();
232 }
233};
234
// Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for every op class in
// OpClassesWithStructTypes.WithGeneralBuilder and for every class in AdditionalOpClasses, plus
// the function-signature conversion pattern for llzk::function::FuncDefOp and the SCF structural
// type conversions.
//
// `target` is only passed on to populateSCFStructuralTypeConversionsAndLegality().
239template <typename... AdditionalOpClasses>
240mlir::RewritePatternSet newGeneralRewritePatternSet(
241 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
242) {
243 mlir::RewritePatternSet patterns(ctx);
// The op-class instances passed to the inserter are used only for their types.
244 auto inserter = [&](auto... opClasses) {
245 patterns.add<GeneralTypeReplacePattern<decltype(opClasses)>...>(tyConv, ctx);
246 };
247 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
248 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
249 // Special cases for ops where GeneralTypeReplacePattern doesn't work
// NOTE(review): listing line 250 is missing from this extract; presumably it registers the
// dedicated CallOp / CreateArrayOp patterns defined earlier in this file -- confirm.
251 // Add builtin FunctionType converter
252 mlir::populateFunctionOpInterfaceTypeConversionPattern<llzk::function::FuncDefOp>(
253 patterns, tyConv
254 );
255 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
256 return patterns;
257}
258
/// Return a new ConversionTarget allowing all LLZK-required dialects.
/// (Declaration only; the definition is not part of this header.)
260mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx);
261
268template <typename... AdditionalOpClasses, typename... AdditionalChecks>
269mlir::ConversionTarget newConverterDefinedTarget(
270 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks
271) {
272 mlir::ConversionTarget target = newBaseTarget(ctx);
273 auto inserter = [&](auto... opClasses) {
274 target.addDynamicallyLegalOp<decltype(opClasses)...>([&tyConv,
275 &checks...](mlir::Operation *op) {
276 return defaultLegalityCheck(tyConv, op) && (runCheck<AdditionalChecks>(op, checks) && ...);
277 });
278 };
279 std::apply(inserter, OpClassesWithStructTypes.NoGeneralBuilder);
280 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
281 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
282 return target;
283}
284
285} // namespace llzk::polymorphic::detail
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::Operation::operand_range getElements()
Definition Ops.h.inc:388
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:272
mlir::LogicalResult matchAndRewrite(llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:219
mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override
Definition SharedImpl.h:186
void rewrite(llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:194
mlir::LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:141
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:240
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:269
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:126
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
bool isNullOrEmpty(mlir::ArrayAttr a)