LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
SharedImpl.h
Go to the documentation of this file.
1//===-- SharedImpl.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
15#pragma once
16
24
25#include <mlir/Dialect/Arith/IR/Arith.h>
26#include <mlir/Dialect/SCF/IR/SCF.h>
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
34
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/ADT/SmallVector.h>
37
38#include <tuple>
39
41
42namespace {
43
/// Registry of op classes (per its name, those whose types may involve struct types) that the
/// conversion helpers below register patterns / legality for. Each member is a value-initialized
/// `const std::tuple` that is only ever consumed through `std::apply`, where the receiving generic
/// lambdas use nothing but `decltype` of each element, i.e. the tuples exist for their types.
///
/// NOTE(review): this object is defined in an anonymous namespace inside a header, so every
/// translation unit including the header gets its own copy, and `static` is redundant there.
/// Harmless for tag-like tuples, but an `inline` variable in the enclosing named namespace would
/// be the conventional choice; confirm before changing.
45static struct {
46
  /// Op classes rewritten by `GeneralTypeReplacePattern` (see `newGeneralRewritePatternSet`); each
  /// must therefore provide the general `build(OpBuilder&, OperationState&, TypeRange, ValueRange,
  /// ArrayRef<NamedAttribute>)` builder.
  // NOTE(review): the element list (source lines 51-68) is missing from this copy of the file.
49 const std::tuple<
50 // clang-format off
69 // clang-format on
70 >
71 WithGeneralBuilder {};
72
  /// Op classes that get dedicated patterns instead of `GeneralTypeReplacePattern`
  /// (`CallOpClassReplacePattern` and the CreateArrayOp replace pattern); presumably because they
  /// lack the general builder, as the name suggests.
77 const std::tuple<llzk::function::CallOp, llzk::array::CreateArrayOp> NoGeneralBuilder {};
78
79} OpClassesWithStructTypes;
80
/// Invoke `inserter` once with one value-initialized object of each type in `OpClasses`
/// (equivalent to `inserter(OpClasses{}...)`, performed via `std::apply`), so that a generic
/// lambda can recover the op classes with `decltype`.
///
/// When `OpClasses` is empty this is a no-op: `inserter` is never called with zero arguments.
/// A single `if constexpr` template replaces the former overload pair, whose empty-pack overload
/// left `inserter` unused (triggering -Wunused-parameter). Callers keep the same syntax:
/// `applyToMoreTypes<decltype(inserter), Extra...>(inserter)`.
///
/// @tparam I         type of the callable (typically a generic lambda).
/// @tparam OpClasses types to forward to `inserter`; may be empty.
/// @param inserter   callable receiving the value-initialized objects.
template <typename I, typename... OpClasses>
inline void applyToMoreTypes([[maybe_unused]] I inserter) {
  if constexpr (sizeof...(OpClasses) > 0) {
    std::apply(inserter, std::tuple<OpClasses...> {});
  }
}
86
87inline bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op) {
88 // Check operand types and result types
89 if (!tyConv.isLegal(op)) {
90 return false;
91 }
92 // Check type attributes
93 // Extend lifetime of temporary to suppress warnings.
94 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
95 for (mlir::NamedAttribute n : dictAttr.getValue()) {
96 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
97 mlir::Type t = tyAttr.getValue();
98 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
99 if (!tyConv.isSignatureLegal(funcTy)) {
100 return false;
101 }
102 } else {
103 if (!tyConv.isLegal(t)) {
104 return false;
105 }
106 }
107 }
108 }
109 return true;
110}
111
112// Default to true if the check is not for that particular operation type.
113template <typename Check> inline bool runCheck(mlir::Operation *op, Check check) {
114 if (auto specificOp =
115 llvm::dyn_cast_if_present<typename llvm::function_traits<Check>::template arg_t<0>>(op)) {
116 return check(specificOp);
117 }
118 return true;
119}
120
121} // namespace
122
125template <typename OpClass, typename Rewriter, typename... Args>
126inline OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args) {
127 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
128 OpClass newOp = rewriter.template replaceOpWithNewOp<OpClass>(op, std::forward<Args>(args)...);
129 newOp->setDiscardableAttrs(attrs);
130 return newOp;
131}
132
133// NOTE: This pattern will produce a compile error if `OpClass` does not define the general
134// `build(OpBuilder&, OperationState&, TypeRange, ValueRange, ArrayRef<NamedAttribute>)` function
135// because that function is required by the `replaceOpWithNewOp()` call.
/// Generic conversion pattern for `OpClass`.
///
/// Converts every result type and every `TypeAttr` in the op's attribute dictionary with the
/// pattern's TypeConverter. It then replaces the op with a new `OpClass` built from:
///   - the converted result types,
///   - the adaptor's (already converted) operands,
///   - the converted attribute list.
/// Fails, emitting an error on the op, if any of those types cannot be converted.
136template <typename OpClass>
137class GeneralTypeReplacePattern : public mlir::OpConversionPattern<OpClass> {
138public:
139 using mlir::OpConversionPattern<OpClass>::OpConversionPattern;
140
  // `OpClass::Adaptor` is spelled out (rather than `OpAdaptor`) because the base class is
  // dependent on `OpClass`, so its member typedefs are not found by unqualified lookup.
141 mlir::LogicalResult matchAndRewrite(
142 OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
143 ) const override {
  // Qualified for the same dependent-base reason as above.
144 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
145 assert(converter);
146 // Convert result types
147 mlir::SmallVector<mlir::Type> newResultTypes;
148 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
149 return op->emitError("Could not convert Op result types.");
150 }
151 // ASSERT: 'adaptor.getAttributes()' is empty or subset of 'op->getAttrDictionary()' so the
152 // former can be ignored without losing anything.
153 assert(
154 adaptor.getAttributes().empty() ||
155 llvm::all_of(
156 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a
157 ) { return d.contains(a.getName()); }
158 )
159 );
160 // Convert any TypeAttr in the attribute list.
  // The copy starts from the op's full attribute dictionary; only TypeAttr values are rewritten,
  // every other attribute is passed through unchanged.
  // NOTE(review): `defaultLegalityCheck` judges FunctionType-valued TypeAttrs with
  // `isSignatureLegal`, whereas this loop hands them to `convertType`. Confirm the converter
  // handles FunctionType if any op using this pattern carries one.
161 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
162 for (mlir::NamedAttribute &n : newAttrs) {
163 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
164 if (mlir::Type newType = converter->convertType(t.getValue())) {
165 n.setValue(mlir::TypeAttr::get(newType));
166 } else {
167 return op->emitError().append("Could not convert type in attribute: ", t);
168 }
169 }
170 }
171 // Build a new Op in place of the current one
  // NOTE(review): the line that opens this call (source line 172) is missing from this copy.
  // Because the arguments start with `rewriter, op`, it is presumably the header's
  // `replaceOpWithNewOp<OpClass>(` wrapper (which also copies discardable attributes). Restore it
  // from the original header.
173 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
174 mlir::ArrayRef(newAttrs)
175 );
176 return mlir::success();
177 }
178};
179
// NOTE(review): the class-head line of this declaration (source line 180) is missing from this
// copy; restore it from the original header.
/// Conversion pattern for `llzk::array::CreateArrayOp`, registered separately from
/// `GeneralTypeReplacePattern`.
///
/// Converts only the op's result type (asserted to be an `ArrayType`) and rebuilds the op in one
/// of two ways:
///   - from the adaptor's element operands, when `numDimsPerMap` is null/empty (as judged by
///     `isNullOrEmpty`);
///   - otherwise from the adaptor's map operands plus `numDimsPerMap`.
181 : public mlir::OpConversionPattern<llzk::array::CreateArrayOp> {
182public:
183 using mlir::OpConversionPattern<llzk::array::CreateArrayOp>::OpConversionPattern;
184
  /// Succeeds iff the TypeConverter can convert the op's result type; otherwise emits an error on
  /// the op and fails.
  /// NOTE(review): only convertibility is checked here, not that the converted type is an
  /// `ArrayType`. `rewrite` relies on an `assert` followed by `llvm::cast`, and both are unchecked
  /// in builds without assertions. The type is also converted a second time in `rewrite`.
185 mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override {
186 if (mlir::Type newType = getTypeConverter()->convertType(op.getType())) {
187 return mlir::success();
188 } else {
189 return op->emitError("Could not convert Op result type.");
190 }
191 }
192
  // NOTE(review): the line opening this declaration (source line 193) is missing from this copy;
  // the symbol list at the end of this page shows it as `void rewrite(`.
194 llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
195 ) const override {
196 mlir::Type newType = getTypeConverter()->convertType(op.getType());
197 assert(
198 llvm::isa<llzk::array::ArrayType>(newType) && "CreateArrayOp must produce ArrayType result"
199 );
200 mlir::DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr();
  // NOTE(review): source lines 202 and 206, which open the two replacement calls taking the
  // arguments below, are missing from this copy.
201 if (isNullOrEmpty(numDimsPerMap)) {
203 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.getElements()
204 );
205 } else {
207 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.getMapOperands(),
208 numDimsPerMap
209 );
210 }
211 }
212};
213
/// Conversion pattern for `llzk::function::CallOp`, registered separately from
/// `GeneralTypeReplacePattern`.
///
/// Converts the call's result types and rebuilds the call with:
///   - the same callee symbol and `numDimsPerMap` attribute;
///   - the map operands and argument operands taken from the adaptor (i.e. already converted).
/// Fails, emitting an error on the op, if a result type cannot be converted.
214class CallOpClassReplacePattern : public mlir::OpConversionPattern<llzk::function::CallOp> {
215public:
216 using mlir::OpConversionPattern<llzk::function::CallOp>::OpConversionPattern;
217
218 mlir::LogicalResult matchAndRewrite(
219 llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
220 ) const override {
221 // Convert the result types of the CallOp
222 mlir::SmallVector<mlir::Type> newResultTypes;
223 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
224 return op->emitError("Could not convert Op result types.");
225 }
  // NOTE(review): the line opening this replacement call (source line 226) is missing from this
  // copy; restore it from the original header.
227 rewriter, op, newResultTypes, op.getCalleeAttr(), adapter.getMapOperands(),
228 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
229 );
230 return mlir::success();
231 }
232};
233
/// Return a new RewritePatternSet containing:
///   - a `GeneralTypeReplacePattern` for every op class in
///     `OpClassesWithStructTypes.WithGeneralBuilder` and in `AdditionalOpClasses`;
///   - MLIR's FunctionOpInterface type-conversion pattern for `llzk::function::FuncDefOp`;
///   - the SCF structural type-conversion patterns.
///
/// @tparam AdditionalOpClasses extra op classes to rewrite with `GeneralTypeReplacePattern`.
/// @param tyConv converter handed to the patterns; presumably it must outlive the returned set.
/// @param ctx    context used to create the pattern set and the patterns.
/// @param target also passed to `populateSCFStructuralTypeConversionsAndLegality`, which per its
///               MLIR documentation configures legality for the SCF ops it converts.
238template <typename... AdditionalOpClasses>
239mlir::RewritePatternSet newGeneralRewritePatternSet(
240 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
241) {
242 mlir::RewritePatternSet patterns(ctx);
  // Generic lambda: each argument is used only for its type, to instantiate one pattern per class.
243 auto inserter = [&](auto... opClasses) {
244 patterns.add<GeneralTypeReplacePattern<decltype(opClasses)>...>(tyConv, ctx);
245 };
246 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
247 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
248 // Special cases for ops where GeneralTypeReplacePattern doesn't work
  // NOTE(review): the line that registers those special-case patterns (source line 249) is missing
  // from this copy. It presumably adds `CallOpClassReplacePattern` and the CreateArrayOp pattern.
250 // Add builtin FunctionType converter
251 mlir::populateFunctionOpInterfaceTypeConversionPattern<llzk::function::FuncDefOp>(
252 patterns, tyConv
253 );
254 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
255 return patterns;
256}
257
259mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx);
260
267template <typename... AdditionalOpClasses, typename... AdditionalChecks>
268mlir::ConversionTarget newConverterDefinedTarget(
269 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks
270) {
271 mlir::ConversionTarget target = newBaseTarget(ctx);
272 auto inserter = [&](auto... opClasses) {
273 target.addDynamicallyLegalOp<decltype(opClasses)...>([&tyConv,
274 &checks...](mlir::Operation *op) {
275 return defaultLegalityCheck(tyConv, op) && (runCheck<AdditionalChecks>(op, checks) && ...);
276 });
277 };
278 std::apply(inserter, OpClassesWithStructTypes.NoGeneralBuilder);
279 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
280 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
281 return target;
282}
283
284} // namespace llzk::polymorphic::detail
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:649
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:412
::mlir::Operation::operand_range getElements()
Definition Ops.cpp.inc:408
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:531
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.cpp.inc:522
mlir::LogicalResult matchAndRewrite(llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:218
mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override
Definition SharedImpl.h:185
void rewrite(llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:193
mlir::LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:141
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:239
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:268
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:126
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
bool isNullOrEmpty(mlir::ArrayAttr a)