LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Function.cpp
Go to the documentation of this file.
1//===-- Function.cpp - Function dialect C API implementation ----*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "llzk/CAPI/Builder.h"
11#include "llzk/CAPI/Support.h"
14
16
17#include <mlir/CAPI/IR.h>
18#include <mlir/CAPI/Pass.h>
19#include <mlir/CAPI/Registration.h>
20#include <mlir/CAPI/Wrap.h>
21#include <mlir/IR/Attributes.h>
22#include <mlir/IR/BuiltinAttributes.h>
23
24#include <mlir-c/IR.h>
25#include <mlir-c/Pass.h>
26
27#include <llvm/ADT/SmallVectorExtras.h>
28
29using namespace llzk::function;
30using namespace mlir;
31using namespace llzk;
32
34
35static NamedAttribute unwrap(MlirNamedAttribute attr) {
36 return NamedAttribute(unwrap(attr.name), unwrap(attr.attribute));
37}
38
39//===----------------------------------------------------------------------===//
40// FuncDefOp
41//===----------------------------------------------------------------------===//
42
46 MlirLocation location, MlirStringRef name, MlirType funcType, intptr_t numAttrs,
47 MlirNamedAttribute const *attrs, intptr_t numArgAttrs, MlirAttribute const *argAttrs
48) {
49 SmallVector<NamedAttribute> attrsSto;
50 SmallVector<Attribute> argAttrsSto;
51 SmallVector<DictionaryAttr> unwrappedArgAttrs =
52 llvm::map_to_vector(unwrapList(numArgAttrs, argAttrs, argAttrsSto), [](auto attr) {
53 return mlir::cast<DictionaryAttr>(attr);
54 });
55 return wrap(FuncDefOp::create(
56 unwrap(location), unwrap(name), mlir::cast<FunctionType>(unwrap(funcType)),
57 unwrapList(numAttrs, attrs, attrsSto), unwrappedArgAttrs
58 ));
59}
60
61bool llzkOperationIsAFuncDefOp(MlirOperation op) { return mlir::isa<FuncDefOp>(unwrap(op)); }
62
64 return mlir::unwrap_cast<FuncDefOp>(op).hasAllowConstraintAttr();
65}
66
67void llzkFuncDefOpSetAllowConstraintAttr(MlirOperation op, bool value) {
68 mlir::unwrap_cast<FuncDefOp>(op).setAllowConstraintAttr(value);
69}
70
72 return mlir::unwrap_cast<FuncDefOp>(op).hasAllowWitnessAttr();
73}
74
75void llzkFuncDefOpSetAllowWitnessAttr(MlirOperation op, bool value) {
76 mlir::unwrap_cast<FuncDefOp>(op).setAllowWitnessAttr(value);
77}
78
79bool llzkFuncDefOpGetHasArgIsPub(MlirOperation op, unsigned argNo) {
80 return mlir::unwrap_cast<FuncDefOp>(op).hasArgPublicAttr(argNo);
81}
82
83MlirAttribute llzkFuncDefOpGetFullyQualifiedName(MlirOperation op) {
84 return wrap(mlir::unwrap_cast<FuncDefOp>(op).getFullyQualifiedName());
85}
86
87bool llzkFuncDefOpGetNameIsCompute(MlirOperation op) {
88 return mlir::unwrap_cast<FuncDefOp>(op).nameIsCompute();
89}
90
91bool llzkFuncDefOpGetNameIsConstrain(MlirOperation op) {
92 return mlir::unwrap_cast<FuncDefOp>(op).nameIsConstrain();
93}
94
95bool llzkFuncDefOpGetIsInStruct(MlirOperation op) {
96 return mlir::unwrap_cast<FuncDefOp>(op).isInStruct();
97}
98
99bool llzkFuncDefOpGetIsStructCompute(MlirOperation op) {
100 return mlir::unwrap_cast<FuncDefOp>(op).isStructCompute();
101}
102
103bool llzkFuncDefOpGetIsStructConstrain(MlirOperation op) {
104 return mlir::unwrap_cast<FuncDefOp>(op).isStructConstrain();
105}
106
109 return wrap(mlir::unwrap_cast<FuncDefOp>(op).getSingleResultTypeOfCompute());
110}
111
112//===----------------------------------------------------------------------===//
113// CallOp
114//===----------------------------------------------------------------------===//
115
116static auto unwrapCallee(MlirOperation op) { return mlir::cast<FuncDefOp>(unwrap(op)); }
117
118static auto unwrapDims(MlirAttribute attr) { return mlir::cast<DenseI32ArrayAttr>(unwrap(attr)); }
119
120static auto unwrapName(MlirAttribute attr) { return mlir::cast<SymbolRefAttr>(unwrap(attr)); }
121
123 CallOp, intptr_t numResults, MlirType const *results, MlirAttribute name, intptr_t numOperands,
124 MlirValue const *operands
125) {
126 SmallVector<Type> resultsSto;
127 SmallVector<Value> operandsSto;
128 return wrap(create<CallOp>(
129 builder, location, unwrapList(numResults, results, resultsSto), unwrapName(name),
130 unwrapList(numOperands, operands, operandsSto)
131 ));
132}
133
135 CallOp, ToCallee, MlirOperation callee, intptr_t numOperands, MlirValue const *operands
136) {
137 SmallVector<Value> operandsSto;
138 return wrap(create<CallOp>(
139 builder, location, unwrapCallee(callee), unwrapList(numOperands, operands, operandsSto)
140 ));
141}
142
144 CallOp, WithMapOperands, intptr_t numResults, MlirType const *results, MlirAttribute name,
145 intptr_t numMapOperands, MlirValueRange const *mapOperands, MlirAttribute numDimsPerMap,
146 intptr_t numArgOperands, MlirValue const *argOperands
147) {
148 SmallVector<Type> resultsSto;
149 SmallVector<Value> argOperandsSto;
150 MapOperandsHelper<> mapOperandsHelper(numMapOperands, mapOperands);
151 return wrap(create<CallOp>(
152 builder, location, unwrapList(numResults, results, resultsSto), unwrapName(name),
153 *mapOperandsHelper, unwrapDims(numDimsPerMap),
154 unwrapList(numArgOperands, argOperands, argOperandsSto)
155 ));
156}
157
159 CallOp, WithMapOperandsAndDims, intptr_t numResults, MlirType const *results,
160 MlirAttribute name, intptr_t numMapOperands, MlirValueRange const *mapOperands,
161 intptr_t numDimsPermMapLength, int32_t const *numDimsPerMap, intptr_t numArgOperands,
162 MlirValue const *argOperands
163) {
164 SmallVector<Type> resultsSto;
165 SmallVector<Value> argOperandsSto;
166 MapOperandsHelper<> mapOperandsHelper(numMapOperands, mapOperands);
167 return wrap(create<CallOp>(
168 builder, location, unwrapList(numResults, results, resultsSto), unwrapName(name),
169 *mapOperandsHelper, ArrayRef(numDimsPerMap, numDimsPermMapLength),
170 unwrapList(numArgOperands, argOperands, argOperandsSto)
171 ));
172}
173
175 CallOp, ToCalleeWithMapOperands, MlirOperation callee, intptr_t numMapOperands,
176 MlirValueRange const *mapOperands, MlirAttribute numDimsPerMap, intptr_t numArgOperands,
177 MlirValue const *argOperands
178) {
179 SmallVector<Value> argOperandsSto;
180 MapOperandsHelper<> mapOperandsHelper(numMapOperands, mapOperands);
181 return wrap(create<CallOp>(
182 builder, location, unwrapCallee(callee), *mapOperandsHelper, unwrapDims(numDimsPerMap),
183 unwrapList(numArgOperands, argOperands, argOperandsSto)
184 ));
185}
186
188 CallOp, ToCalleeWithMapOperandsAndDims, MlirOperation callee, intptr_t numMapOperands,
189 MlirValueRange const *mapOperands, intptr_t numDimsPermMapLength, int32_t const *numDimsPerMap,
190 intptr_t numArgOperands, MlirValue const *argOperands
191) {
192 SmallVector<Value> argOperandsSto;
193 MapOperandsHelper<> mapOperandsHelper(numMapOperands, mapOperands);
194 return wrap(create<CallOp>(
195 builder, location, unwrapCallee(callee), *mapOperandsHelper,
196 ArrayRef(numDimsPerMap, numDimsPermMapLength),
197 unwrapList(numArgOperands, argOperands, argOperandsSto)
198 ));
199}
200
201bool llzkOperationIsACallOp(MlirOperation op) { return mlir::isa<CallOp>(unwrap(op)); }
202
203MlirType llzkCallOpGetCalleeType(MlirOperation op) {
204 return wrap(mlir::unwrap_cast<CallOp>(op).getCalleeType());
205}
206
207bool llzkCallOpGetCalleeIsCompute(MlirOperation op) {
208 return mlir::unwrap_cast<CallOp>(op).calleeIsCompute();
209}
210
211bool llzkCallOpGetCalleeIsConstrain(MlirOperation op) {
212 return mlir::unwrap_cast<CallOp>(op).calleeIsConstrain();
213}
214
216 return mlir::unwrap_cast<CallOp>(op).calleeIsStructCompute();
217}
218
220 return mlir::unwrap_cast<CallOp>(op).calleeIsStructConstrain();
221}
222
223MlirType llzkCallOpGetSingleResultTypeOfCompute(MlirOperation op) {
224 return wrap(mlir::unwrap_cast<CallOp>(op).getSingleResultTypeOfCompute());
225}
bool llzkFuncDefOpGetHasAllowWitnessAttr(MlirOperation op)
Definition Function.cpp:71
MlirType llzkFuncDefOpGetSingleResultTypeOfCompute(MlirOperation op)
Assuming the function is the compute function, returns its StructType result.
Definition Function.cpp:108
MlirType llzkCallOpGetCalleeType(MlirOperation op)
Returns the FunctionType of the callee.
Definition Function.cpp:203
MlirOperation llzkFuncDefOpCreateWithAttrsAndArgAttrs(MlirLocation location, MlirStringRef name, MlirType funcType, intptr_t numAttrs, MlirNamedAttribute const *attrs, intptr_t numArgAttrs, MlirAttribute const *argAttrs)
Creates a FuncDefOp with the given attributes and argument attributes.
Definition Function.cpp:45
bool llzkFuncDefOpGetNameIsConstrain(MlirOperation op)
Definition Function.cpp:91
bool llzkCallOpGetCalleeIsStructConstrain(MlirOperation op)
Definition Function.cpp:219
bool llzkFuncDefOpGetIsInStruct(MlirOperation op)
Definition Function.cpp:95
bool llzkFuncDefOpGetHasArgIsPub(MlirOperation op, unsigned argNo)
Definition Function.cpp:79
bool llzkFuncDefOpGetIsStructCompute(MlirOperation op)
Definition Function.cpp:99
bool llzkCallOpGetCalleeIsConstrain(MlirOperation op)
Definition Function.cpp:211
void llzkFuncDefOpSetAllowConstraintAttr(MlirOperation op, bool value)
Sets the allow_constraint attribute in the FuncDefOp operation.
Definition Function.cpp:67
bool llzkCallOpGetCalleeIsStructCompute(MlirOperation op)
Definition Function.cpp:215
MlirType llzkCallOpGetSingleResultTypeOfCompute(MlirOperation op)
Assuming the callee is the compute function, returns its StructType result.
Definition Function.cpp:223
bool llzkFuncDefOpGetIsStructConstrain(MlirOperation op)
Definition Function.cpp:103
bool llzkOperationIsAFuncDefOp(MlirOperation op)
Definition Function.cpp:61
bool llzkCallOpGetCalleeIsCompute(MlirOperation op)
Definition Function.cpp:207
bool llzkOperationIsACallOp(MlirOperation op)
Definition Function.cpp:201
bool llzkFuncDefOpGetNameIsCompute(MlirOperation op)
Definition Function.cpp:87
bool llzkFuncDefOpGetHasAllowConstraintAttr(MlirOperation op)
Definition Function.cpp:63
MlirAttribute llzkFuncDefOpGetFullyQualifiedName(MlirOperation op)
Returns the fully qualified name of the function.
Definition Function.cpp:83
void llzkFuncDefOpSetAllowWitnessAttr(MlirOperation op, bool value)
Sets the allow_witness attribute in the FuncDefOp operation.
Definition Function.cpp:75
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Polymorphic, llzk__polymorphic, llzk::polymorphic::PolymorphicDialect) MlirType llzkTypeVarTypeGet(MlirContext ctx
MlirStringRef name
Helper for unwrapping the C arguments for the map operands.
Definition Support.h:36
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
#define LLZK_DEFINE_OP_BUILD_METHOD(op,...)
Definition Support.h:27
#define LLZK_DEFINE_SUFFIX_OP_BUILD_METHOD(op, suffix,...)
Definition Support.h:25
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
auto unwrap_cast(auto &from)
Definition Support.h:30