LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Types.cpp
Go to the documentation of this file.
1//===-- Types.cpp - Array type implementations ------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
14
15#include <mlir/IR/BuiltinTypeInterfaces.h>
16
17using namespace mlir;
18
19namespace llzk::array {
20
22 MLIRContext *ctx, ArrayRef<int64_t> shape, SmallVector<Attribute> &dimensionSizes
23) {
24 Builder builder(ctx);
25 dimensionSizes = llvm::map_to_vector(shape, [&builder](int64_t v) -> Attribute {
26 return builder.getIndexAttr(v);
27 });
28 assert(dimensionSizes.size() == shape.size()); // fully computed by this function
29 return success();
30}
31
33 EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes, SmallVector<int64_t> &shape
34) {
35 assert(shape.empty()); // fully computed by this function
36
37 // Ensure all Attributes are valid Attribute classes for ArrayType.
38 if (failed(verifyArrayDimSizes(emitError, dimensionSizes))) {
39 return failure();
40 }
41
42 // Convert the Attributes to int64_t
43 for (Attribute a : dimensionSizes) {
44 if (auto p = llvm::dyn_cast_if_present<IntegerAttr>(a)) {
45 shape.push_back(fromAPInt(p.getValue()));
46 } else if (llvm::isa_and_present<SymbolRefAttr, AffineMapAttr>(a)) {
47 // The ShapedTypeInterface uses 'kDynamic' for dimensions with non-static size.
48 shape.push_back(ShapedType::kDynamic);
49 } else {
50 // For every Attribute class in ArrayDimensionTypes, there should be a case here.
51 llvm::report_fatal_error("computeShapeFromDims() is out of sync with ArrayDimensionTypes");
52 return failure();
53 }
54 }
55 assert(shape.size() == dimensionSizes.size()); // fully computed by this function
56 return success();
57}
58
60 AsmParser &parser, SmallVector<int64_t> &shape, SmallVector<Attribute> dimensionSizes
61) {
62 // This is not actually parsing. It's computing the derived
63 // `shape` from the `dimensionSizes` attributes.
64 auto emitError = [&parser] {
65 return InFlightDiagnosticWrapper(parser.emitError(parser.getCurrentLocation()));
66 };
67 return computeShapeFromDims(emitError, dimensionSizes, shape);
68}
69void printDerivedShape(AsmPrinter &, ArrayRef<int64_t>, ArrayRef<Attribute>) {
70 // nothing to print, it's derived and therefore not represented in the output
71}
72
73LogicalResult ArrayType::verify(
74 function_ref<InFlightDiagnostic()> emitError, Type elementType,
75 ArrayRef<Attribute> dimensionSizes, ArrayRef<int64_t> shape
76) {
77 return verifyArrayType(wrapNonNullableInFlightDiagnostic(emitError), elementType, dimensionSizes);
78}
79
80ArrayType ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const {
81 return ArrayType::get(elementType, shape.has_value() ? shape.value() : getShape());
82}
83
85ArrayType::cloneWith(Type elementType, std::optional<ArrayRef<Attribute>> dimensions) const {
86 return ArrayType::get(
87 elementType, dimensions.has_value() ? dimensions.value() : getDimensionSizes()
88 );
89}
90
91namespace {
92
93inline ArrayType createArrayOfSizeOne(Type elemType) { return ArrayType::get(elemType, {1}); }
94
95} // namespace
96
97bool ArrayType::collectIndices(llvm::function_ref<void(ArrayAttr)> inserter) const {
98 if (!hasStaticShape()) {
99 return false;
100 }
101 MLIRContext *ctx = getContext();
102 ArrayIndexGen idxGen = ArrayIndexGen::from(*this);
103 for (int64_t e = getNumElements(), i = 0; i < e; ++i) {
104 auto delinearized = idxGen.delinearize(i, ctx);
105 assert(delinearized.has_value()); // cannot fail since loop is over array size
106 inserter(ArrayAttr::get(ctx, delinearized.value()));
107 }
108 return true;
109}
110
111std::optional<SmallVector<ArrayAttr>> ArrayType::getSubelementIndices() const {
112 SmallVector<ArrayAttr> ret;
113 bool success = collectIndices([&ret](ArrayAttr v) { ret.push_back(v); });
114 return success ? std::make_optional(ret) : std::nullopt;
115}
116
118std::optional<DenseMap<Attribute, Type>> ArrayType::getSubelementIndexMap() const {
119 DenseMap<Attribute, Type> ret;
120 Type destructAs = createArrayOfSizeOne(getElementType());
121 bool success = collectIndices([&](ArrayAttr v) { ret[v] = destructAs; });
122 return success ? std::make_optional(ret) : std::nullopt;
123}
124
126Type ArrayType::getTypeAtIndex(Attribute index) const {
127 if (!hasStaticShape()) {
128 return nullptr;
129 }
130 // Since indexing is multi-dimensional, `index` should be ArrayAttr
131 ArrayAttr indexAttr = llvm::dyn_cast<ArrayAttr>(index);
132 if (!indexAttr) {
133 return nullptr;
134 }
135 // Ensure the shape is valid and dimensions are valid for the shape by computing linear index.
136 if (!ArrayIndexGen::from(*this).linearize(indexAttr.getValue())) {
137 return nullptr;
138 }
139 // If that's successful, the destructured type is the size-1 array of the element type.
140 return createArrayOfSizeOne(getElementType());
141}
142
143ParseResult parseAttrVec(AsmParser &parser, SmallVector<Attribute> &value) {
144 SmallVector<Attribute> attrs;
145 auto parseElement = [&parser, &value]() -> ParseResult {
146 auto qResult = parser.parseOptionalQuestion();
147 if (succeeded(qResult)) {
148 auto &builder = parser.getBuilder();
149 value.push_back(builder.getIntegerAttr(builder.getIndexType(), ShapedType::kDynamic));
150 return qResult;
151 }
152 auto attrParseResult = FieldParser<Attribute>::parse(parser);
153 if (succeeded(attrParseResult)) {
154 auto emitError = [&parser] {
155 return InFlightDiagnosticWrapper(parser.emitError(parser.getCurrentLocation()));
156 };
157 FailureOr<Attribute> forced = forceIntAttrType(*attrParseResult, emitError);
158 if (failed(forced)) {
159 return failure();
160 }
161 value.push_back(*forced);
162 }
163 return ParseResult(attrParseResult);
164 };
165 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElement))) {
166 return parser.emitError(parser.getCurrentLocation(), "failed to parse array dimensions");
167 }
168 return success();
169}
170
171void printAttrVec(AsmPrinter &printer, ArrayRef<Attribute> value) {
172 printAttrs(printer, value, ",");
173}
174
175} // namespace llzk::array
Wrapper around InFlightDiagnostic that can either be a regular InFlightDiagnostic or a special versio...
Definition ErrorHelper.h:26
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
::llvm::ArrayRef< int64_t > getShape() const
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:111
::mlir::Type getElementType() const
::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type > > getSubelementIndexMap() const
Required by DestructurableTypeInterface / SROA pass.
Definition Types.cpp:118
::mlir::Type getTypeAtIndex(::mlir::Attribute index) const
Required by DestructurableTypeInterface / SROA pass.
Definition Types.cpp:126
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes, ::llvm::ArrayRef< int64_t > shape)
Definition Types.cpp:73
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
LogicalResult computeShapeFromDims(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes, SmallVector< int64_t > &shape)
Definition Types.cpp:32
ParseResult parseAttrVec(AsmParser &parser, SmallVector< Attribute > &value)
Definition Types.cpp:143
void printDerivedShape(AsmPrinter &, ArrayRef< int64_t >, ArrayRef< Attribute >)
Definition Types.cpp:69
LogicalResult computeDimsFromShape(MLIRContext *ctx, ArrayRef< int64_t > shape, SmallVector< Attribute > &dimensionSizes)
Definition Types.cpp:21
void printAttrVec(AsmPrinter &printer, ArrayRef< Attribute > value)
Definition Types.cpp:171
ParseResult parseDerivedShape(AsmParser &parser, SmallVector< int64_t > &shape, SmallVector< Attribute > dimensionSizes)
Definition Types.cpp:59
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
void printAttrs(AsmPrinter &printer, ArrayRef< Attribute > attrs, const StringRef &separator)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
OwningEmitErrorFn wrapNonNullableInFlightDiagnostic(llvm::function_ref< mlir::InFlightDiagnostic()> emitError)
int64_t fromAPInt(const llvm::APInt &i)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)