LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ArrayTypeHelper.cpp
Go to the documentation of this file.
1//===-- ArrayTypeHelper.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
17
18#include <mlir/Dialect/Arith/IR/Arith.h>
19#include <mlir/Dialect/Utils/IndexingUtils.h>
20#include <mlir/IR/Matchers.h>
21
22#include <llvm/ADT/APInt.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/STLFunctionalExtras.h>
25
26using namespace mlir;
27using namespace llzk;
28using namespace llzk::array;
29
30ArrayIndexGen::ArrayIndexGen(ArrayType t)
31 : shape(t.getShape()), linearSize(t.getNumElements()), strides(mlir::computeStrides(shape)) {}
32
34 assert(t.hasStaticShape());
35 return ArrayIndexGen(t);
36}
37
38namespace {
39
40inline bool isInRange(int64_t idx, int64_t dimSize) { return 0 <= idx && idx < dimSize; }
41
42// This can support Value, Attribute, and Operation* per matchPattern() implementations.
43template <typename TypeOfIndex> inline std::optional<int64_t> toI64(TypeOfIndex index) {
44 llvm::APInt idxAP;
45 if (!mlir::matchPattern(index, mlir::m_ConstantInt(&idxAP))) {
46 return std::nullopt;
47 }
48 return llzk::fromAPInt(idxAP);
49}
50
51template <typename OutType> struct CheckAndConvert {
52 template <typename InType>
53 static std::optional<OutType> from(InType /*index*/, int64_t /*dimSize*/) {
54 static_assert(sizeof(OutType) == 0, "CheckAndConvert not implemented for requested type.");
55 assert(false);
56 }
57};
58
59// Specialization to produce `int64_t`
60template <> struct CheckAndConvert<int64_t> {
61 template <typename InType> static std::optional<int64_t> from(InType index, int64_t dimSize) {
62 if (auto idxVal = toI64<InType>(index)) {
63 if (isInRange(*idxVal, dimSize)) {
64 return *idxVal;
65 }
66 }
67 return std::nullopt;
68 }
69};
70
71// Specialization to produce `Attribute`
72template <> struct CheckAndConvert<Attribute> {
73 template <typename InType> static std::optional<Attribute> from(InType index, int64_t dimSize) {
74 if (auto c = CheckAndConvert<int64_t>::from(index, dimSize)) {
75 return IntegerAttr::get(IndexType::get(index.getContext()), *c);
76 }
77 return std::nullopt;
78 }
79};
80
81template <typename OutType, typename InListType>
82inline std::optional<SmallVector<OutType>>
83checkAndConvertMulti(InListType multiDimIndex, ArrayRef<int64_t> shape, bool mustBeEqual) {
84 if (mustBeEqual) {
85 assert(
86 llvm::all_equal({llvm::range_size(multiDimIndex), llvm::range_size(shape)}) &&
87 "Iteratees do not have equal length"
88 );
89 }
90 SmallVector<OutType> ret;
91 for (auto [idx, dimSize] : llvm::zip_first(multiDimIndex, shape)) {
92 std::optional<OutType> next = CheckAndConvert<OutType>::from(idx, dimSize);
93 if (!next.has_value()) {
94 return std::nullopt;
95 }
96 ret.push_back(next.value());
97 }
98 return ret;
99}
100
101inline std::optional<int64_t> linearizeImpl(
102 ArrayRef<int64_t> multiDimIndex, const ArrayRef<int64_t> &shape,
103 const SmallVector<int64_t> &strides
104) {
105 // Ensure the index for each dimension is in range. Then the linearized index will be as well.
106 for (auto [idx, dimSize] : llvm::zip_equal(multiDimIndex, shape)) {
107 if (!isInRange(idx, dimSize)) {
108 return std::nullopt;
109 }
110 }
111 return mlir::linearize(multiDimIndex, strides);
112}
113
114template <typename TypeOfIndex>
115inline std::optional<int64_t> linearizeImpl(
116 ArrayRef<TypeOfIndex> multiDimIndex, const ArrayRef<int64_t> &shape,
117 const SmallVector<int64_t> &strides
118) {
119 std::optional<SmallVector<int64_t>> conv =
120 checkAndConvertMulti<int64_t>(multiDimIndex, shape, true /*TODO: I think*/);
121 if (!conv.has_value()) {
122 return std::nullopt;
123 }
124 return mlir::linearize(conv.value(), strides);
125}
126
127template <typename ResultElemType>
128inline std::optional<SmallVector<ResultElemType>> delinearizeImpl(
129 int64_t linearIndex, int64_t linearSize, const SmallVector<int64_t> &strides, MLIRContext *ctx,
130 llvm::function_ref<ResultElemType(IntegerAttr)> convert
131) {
132 if (!isInRange(linearIndex, linearSize)) {
133 return std::nullopt;
134 }
135 SmallVector<ResultElemType> ret;
136 for (int64_t idx : mlir::delinearize(linearIndex, strides)) {
137 ret.push_back(convert(IntegerAttr::get(IndexType::get(ctx), idx)));
138 }
139 return ret;
140}
141
142} // namespace
143
144std::optional<SmallVector<Value>>
145ArrayIndexGen::delinearize(int64_t linearIndex, Location loc, OpBuilder &bldr) const {
146 return delinearizeImpl<Value>(
147 linearIndex, linearSize, strides, bldr.getContext(),
148 [&](IntegerAttr a) { return bldr.create<arith::ConstantOp>(loc, a); }
149 );
150}
151
152std::optional<SmallVector<Attribute>>
153ArrayIndexGen::delinearize(int64_t linearIndex, MLIRContext *ctx) const {
154 return delinearizeImpl<Attribute>(linearIndex, linearSize, strides, ctx, [](IntegerAttr a) {
155 return a;
156 });
157}
158
159template <typename InListType> std::optional<int64_t> ArrayIndexGen::linearize(InListType) const {
160 static_assert(sizeof(InListType) == 0, "linearize() not implemented for requested type.");
161 llvm_unreachable("must have concrete instantiation");
162 return std::nullopt;
163}
164
165template <> std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<int64_t> multiDimIndex) const {
166 return linearizeImpl(multiDimIndex, shape, strides);
167}
168
169template <>
170std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<Attribute> multiDimIndex) const {
171 return linearizeImpl(multiDimIndex, shape, strides);
172}
173
174template <>
175std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<Operation *> multiDimIndex) const {
176 return linearizeImpl(multiDimIndex, shape, strides);
177}
178
179template <> std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<Value> multiDimIndex) const {
180 return linearizeImpl(multiDimIndex, shape, strides);
181}
182
183template <typename InListType>
184std::optional<SmallVector<Attribute>> ArrayIndexGen::checkAndConvert(InListType) {
185 static_assert(sizeof(InListType) == 0, "checkAndConvert() not implemented for requested type.");
186 llvm_unreachable("must have concrete instantiation");
187 return std::nullopt;
188}
189
190template <>
191std::optional<SmallVector<Attribute>> ArrayIndexGen::checkAndConvert(OperandRange multiDimIndex) {
192 return checkAndConvertMulti<Attribute>(multiDimIndex, shape, false);
193}
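Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.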
Definition LICENSE.txt:45
std::optional< llvm::SmallVector< mlir::Attribute > > checkAndConvert(InListType multiDimIndex)
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< int64_t > linearize(InListType multiDimIndex) const
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
int64_t fromAPInt(const llvm::APInt &i)