LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKConversionUtils.h
Go to the documentation of this file.
1//===-- LLZKConversionUtils.h -----------------------------------*- C++ -*-===//
2//
3// Shared utilities for dialect converting transformations.
4//
5//===----------------------------------------------------------------------===//
6
7#ifndef LLZK_TRANSFORMS_CONVERSION_UTILS_H
8#define LLZK_TRANSFORMS_CONVERSION_UTILS_H
9
11
12#include <mlir/IR/PatternMatch.h>
13
14namespace llzk {
15
19
20protected:
21 virtual llvm::SmallVector<mlir::Type> convertInputs(mlir::ArrayRef<mlir::Type> origTypes) = 0;
22 virtual llvm::SmallVector<mlir::Type> convertResults(mlir::ArrayRef<mlir::Type> origTypes) = 0;
23
24 virtual mlir::ArrayAttr
25 convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
26 virtual mlir::ArrayAttr
27 convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
28
29 virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter) = 0;
30
31public:
32 virtual ~FunctionTypeConverter() = default;
33
34 void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter) {
35 // Update in/out types of the function
36 mlir::FunctionType oldTy = op.getFunctionType();
37 llvm::SmallVector<mlir::Type> newInputs = convertInputs(oldTy.getInputs());
38 llvm::SmallVector<mlir::Type> newResults = convertResults(oldTy.getResults());
39 mlir::FunctionType newTy = mlir::FunctionType::get(
40 oldTy.getContext(), mlir::TypeRange(newInputs), mlir::TypeRange(newResults)
41 );
42 if (newTy == oldTy) {
43 return; // nothing to change
44 }
45
46 // Pre-condition: arg/result count equals corresponding attribute count
47 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
48 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
49 rewriter.modifyOpInPlace(op, [&]() {
50 op.setFunctionType(newTy);
51
52 // If any input or result types were added, ensure the attributes are updated too.
53 if (mlir::ArrayAttr newArgAttrs = convertInputAttrs(op.getArgAttrsAttr(), newInputs)) {
54 op.setArgAttrsAttr(newArgAttrs);
55 }
56 if (mlir::ArrayAttr newResAttrs = convertResultAttrs(op.getResAttrsAttr(), newResults)) {
57 op.setResAttrsAttr(newResAttrs);
58 }
59 });
60 // Post-condition: arg/result count equals corresponding attribute count
61 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
62 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
63
64 // If the function has a body, ensure the entry block arguments match the function inputs.
65 if (mlir::Region *body = op.getCallableRegion()) {
66 mlir::Block &entryBlock = body->front();
67 if (!std::cmp_equal(entryBlock.getNumArguments(), newInputs.size())) {
68 processBlockArgs(entryBlock, rewriter);
69 // Post-condition: block args must match function inputs
70 assert(std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()));
71 }
72 }
73 }
74};
75
76} // namespace llzk
77
78#endif // LLZK_TRANSFORMS_CONVERSION_UTILS_H
General helper for converting a FuncDefOp by changing its input and/or result types and the associated argument/result attributes and entry-block arguments.
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter)
virtual ~FunctionTypeConverter()=default
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
::mlir::ArrayAttr getArgAttrsAttr()
Definition Ops.cpp.inc:1100
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
void setArgAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.cpp.inc:1134
void setResAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.cpp.inc:1138
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1130
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
Definition Ops.h.inc:604
::mlir::ArrayAttr getResAttrsAttr()
Definition Ops.cpp.inc:1109