LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKConversionUtils.h
Go to the documentation of this file.
1//===-- LLZKConversionUtils.h -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Shared utilities for dialect converting transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLZK_TRANSFORMS_CONVERSION_UTILS_H
15#define LLZK_TRANSFORMS_CONVERSION_UTILS_H
16
18
19#include <mlir/IR/PatternMatch.h>
20
21namespace llzk {
22
26
27protected:
28 virtual llvm::SmallVector<mlir::Type> convertInputs(mlir::ArrayRef<mlir::Type> origTypes) = 0;
29 virtual llvm::SmallVector<mlir::Type> convertResults(mlir::ArrayRef<mlir::Type> origTypes) = 0;
30
31 virtual mlir::ArrayAttr
32 convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
33 virtual mlir::ArrayAttr
34 convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
35
36 virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter) = 0;
37
38public:
39 virtual ~FunctionTypeConverter() = default;
40
41 void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter) {
42 // Update in/out types of the function
43 mlir::FunctionType oldTy = op.getFunctionType();
44 llvm::SmallVector<mlir::Type> newInputs = convertInputs(oldTy.getInputs());
45 llvm::SmallVector<mlir::Type> newResults = convertResults(oldTy.getResults());
46 mlir::FunctionType newTy = mlir::FunctionType::get(
47 oldTy.getContext(), mlir::TypeRange(newInputs), mlir::TypeRange(newResults)
48 );
49 if (newTy == oldTy) {
50 return; // nothing to change
51 }
52
53 // Pre-condition: arg/result count equals corresponding attribute count
54 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
55 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
56 rewriter.modifyOpInPlace(op, [&]() {
57 op.setFunctionType(newTy);
58
59 // If any input or result types were added, ensure the attributes are updated too.
60 if (mlir::ArrayAttr newArgAttrs = convertInputAttrs(op.getArgAttrsAttr(), newInputs)) {
61 op.setArgAttrsAttr(newArgAttrs);
62 }
63 if (mlir::ArrayAttr newResAttrs = convertResultAttrs(op.getResAttrsAttr(), newResults)) {
64 op.setResAttrsAttr(newResAttrs);
65 }
66 });
67 // Post-condition: arg/result count equals corresponding attribute count
68 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
69 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
70
71 // If the function has a body, ensure the entry block arguments match the function inputs.
72 if (mlir::Region *body = op.getCallableRegion()) {
73 mlir::Block &entryBlock = body->front();
74 if (!std::cmp_equal(entryBlock.getNumArguments(), newInputs.size())) {
75 processBlockArgs(entryBlock, rewriter);
76 // Post-condition: block args must match function inputs
77 assert(std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()));
78 }
79 }
80 }
81};
82
83} // namespace llzk
84
85#endif // LLZK_TRANSFORMS_CONVERSION_UTILS_H
General helper for converting a FuncDefOp by changing its input and/or result types and the associated argument/result attributes.
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter)
virtual ~FunctionTypeConverter()=default
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
void setArgAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:650
void setResAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:654
::mlir::ArrayAttr getArgAttrsAttr()
Definition Ops.h.inc:630
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:971
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
Definition Ops.h.inc:746
::mlir::ArrayAttr getResAttrsAttr()
Definition Ops.h.inc:635