LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
Types.cpp
Go to the documentation of this file.
1//===-- Types.cpp - Array type implementations ------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
14
15using namespace mlir;
16
17namespace llzk::array {
18
20 MLIRContext *ctx, ArrayRef<int64_t> shape, SmallVector<Attribute> &dimensionSizes
21) {
22 Builder builder(ctx);
23 dimensionSizes = llvm::map_to_vector(shape, [&builder](int64_t v) -> Attribute {
24 return builder.getIndexAttr(v);
25 });
26 assert(dimensionSizes.size() == shape.size()); // fully computed by this function
27 return success();
28}
29
31 EmitErrorFn emitError, MLIRContext *ctx, ArrayRef<Attribute> dimensionSizes,
32 SmallVector<int64_t> &shape
33) {
34 assert(shape.empty()); // fully computed by this function
35
36 // Ensure all Attributes are valid Attribute classes for ArrayType.
37 // In the case where `emitError==null`, we mirror how the verification failure is handled by
38 // `*Type::get()` via `StorageUserBase` (i.e. use DefaultDiagnosticEmitFn and assert). See:
39 // https://github.com/llvm/llvm-project/blob/0897373f1a329a7a02f8ce3c501a05d2f9c89390/mlir/include/mlir/IR/StorageUniquerSupport.h#L179-L180
40 auto errFunc = emitError ? llvm::unique_function<InFlightDiagnostic()>(emitError)
41 : mlir::detail::getDefaultDiagnosticEmitFn(ctx);
42 if (verifyArrayDimSizes(errFunc, dimensionSizes).failed()) {
43 assert(emitError);
44 return failure();
45 }
46
47 // Convert the Attributes to int64_t
48 for (Attribute a : dimensionSizes) {
49 if (auto p = llvm::dyn_cast_if_present<IntegerAttr>(a)) {
50 shape.push_back(fromAPInt(p.getValue()));
51 } else if (llvm::isa_and_present<SymbolRefAttr, AffineMapAttr>(a)) {
52 // The ShapedTypeInterface uses 'kDynamic' for dimensions with non-static size.
53 shape.push_back(ShapedType::kDynamic);
54 } else {
55 // For every Attribute class in ArrayDimensionTypes, there should be a case here.
56 llvm::report_fatal_error("computeShapeFromDims() is out of sync with ArrayDimensionTypes");
57 return failure();
58 }
59 }
60 assert(shape.size() == dimensionSizes.size()); // fully computed by this function
61 return success();
62}
63
65 AsmParser &parser, SmallVector<int64_t> &shape, SmallVector<Attribute> dimensionSizes
66) {
67 // This is not actually parsing. It's computing the derived
68 // `shape` from the `dimensionSizes` attributes.
69 auto emitError = [&parser] { return parser.emitError(parser.getCurrentLocation()); };
70 return computeShapeFromDims(emitError, parser.getContext(), dimensionSizes, shape);
71}
72void printDerivedShape(AsmPrinter &, ArrayRef<int64_t>, ArrayRef<Attribute>) {
73 // nothing to print, it's derived and therefore not represented in the output
74}
75
76LogicalResult ArrayType::verify(
77 EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes,
78 ArrayRef<int64_t> shape
79) {
80 return verifyArrayType(emitError, elementType, dimensionSizes);
81}
82
83ArrayType ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const {
84 return ArrayType::get(elementType, shape.has_value() ? shape.value() : getShape());
85}
86
88ArrayType::cloneWith(Type elementType, std::optional<ArrayRef<Attribute>> dimensions) const {
89 return ArrayType::get(
90 elementType, dimensions.has_value() ? dimensions.value() : getDimensionSizes()
91 );
92}
93
94namespace {
95
96inline ArrayType createArrayOfSizeOne(Type elemType) { return ArrayType::get(elemType, {1}); }
97
98} // namespace
99
100bool ArrayType::collectIndices(llvm::function_ref<void(ArrayAttr)> inserter) const {
101 if (!hasStaticShape()) {
102 return false;
103 }
104 MLIRContext *ctx = getContext();
105 ArrayIndexGen idxGen = ArrayIndexGen::from(*this);
106 for (int64_t e = getNumElements(), i = 0; i < e; ++i) {
107 auto delinearized = idxGen.delinearize(i, ctx);
108 assert(delinearized.has_value()); // cannot fail since loop is over array size
109 inserter(ArrayAttr::get(ctx, delinearized.value()));
110 }
111 return true;
112}
113
114std::optional<SmallVector<ArrayAttr>> ArrayType::getSubelementIndices() const {
115 SmallVector<ArrayAttr> ret;
116 bool success = collectIndices([&ret](ArrayAttr v) { ret.push_back(v); });
117 return success ? std::make_optional(ret) : std::nullopt;
118}
119
121std::optional<DenseMap<Attribute, Type>> ArrayType::getSubelementIndexMap() const {
122 DenseMap<Attribute, Type> ret;
123 Type destructAs = createArrayOfSizeOne(getElementType());
124 bool success = collectIndices([&](ArrayAttr v) { ret[v] = destructAs; });
125 return success ? std::make_optional(ret) : std::nullopt;
126}
127
129Type ArrayType::getTypeAtIndex(Attribute index) const {
130 if (!hasStaticShape()) {
131 return nullptr;
132 }
133 // Since indexing is multi-dimensional, `index` should be ArrayAttr
134 ArrayAttr indexAttr = llvm::dyn_cast<ArrayAttr>(index);
135 if (!indexAttr) {
136 return nullptr;
137 }
138 // Ensure the shape is valid and dimensions are valid for the shape by computing linear index.
139 if (!ArrayIndexGen::from(*this).linearize(indexAttr.getValue())) {
140 return nullptr;
141 }
142 // If that's successful, the destructured type is the size-1 array of the element type.
143 return createArrayOfSizeOne(getElementType());
144}
145
146ParseResult parseAttrVec(AsmParser &parser, SmallVector<Attribute> &value) {
147 SmallVector<Attribute> attrs;
148 auto parseElement = [&]() -> ParseResult {
149 auto qResult = parser.parseOptionalQuestion();
150 if (succeeded(qResult)) {
151 auto &builder = parser.getBuilder();
152 value.push_back(builder.getIntegerAttr(builder.getIndexType(), ShapedType::kDynamic));
153 return qResult;
154 }
155 auto attrParseResult = FieldParser<Attribute>::parse(parser);
156 if (succeeded(attrParseResult)) {
157 value.push_back(forceIntAttrType(*attrParseResult));
158 }
159 return ParseResult(attrParseResult);
160 };
161 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElement))) {
162 return parser.emitError(parser.getCurrentLocation(), "failed to parse array dimensions");
163 }
164 return success();
165}
166
167void printAttrVec(AsmPrinter &printer, ArrayRef<Attribute> value) {
168 printAttrs(printer, value, ",");
169}
170
171} // namespace llzk::array
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
::llvm::ArrayRef< int64_t > getShape() const
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:114
::mlir::Type getElementType() const
::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type > > getSubelementIndexMap() const
Required by DestructurableTypeInterface / SROA pass.
Definition Types.cpp:121
::mlir::Type getTypeAtIndex(::mlir::Attribute index) const
Required by DestructurableTypeInterface / SROA pass.
Definition Types.cpp:129
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes, ::llvm::ArrayRef< int64_t > shape)
Definition Types.cpp:76
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
ParseResult parseAttrVec(AsmParser &parser, SmallVector< Attribute > &value)
Definition Types.cpp:146
void printDerivedShape(AsmPrinter &, ArrayRef< int64_t >, ArrayRef< Attribute >)
Definition Types.cpp:72
LogicalResult computeDimsFromShape(MLIRContext *ctx, ArrayRef< int64_t > shape, SmallVector< Attribute > &dimensionSizes)
Definition Types.cpp:19
void printAttrVec(AsmPrinter &printer, ArrayRef< Attribute > value)
Definition Types.cpp:167
LogicalResult computeShapeFromDims(EmitErrorFn emitError, MLIRContext *ctx, ArrayRef< Attribute > dimensionSizes, SmallVector< int64_t > &shape)
Definition Types.cpp:30
ParseResult parseDerivedShape(AsmParser &parser, SmallVector< int64_t > &shape, SmallVector< Attribute > dimensionSizes)
Definition Types.cpp:64
llvm::function_ref< mlir::InFlightDiagnostic()> EmitErrorFn
Definition ErrorHelper.h:18
void printAttrs(AsmPrinter &printer, ArrayRef< Attribute > attrs, const StringRef &separator)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
Attribute forceIntAttrType(Attribute attr)
int64_t fromAPInt(llvm::APInt i)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)