18#include <mlir/Dialect/Arith/IR/Arith.h>
19#include <mlir/Dialect/Utils/IndexingUtils.h>
20#include <mlir/IR/Matchers.h>
22#include <llvm/ADT/APInt.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/STLFunctionalExtras.h>
31 : shape(t.getShape()), linearSize(t.getNumElements()), strides(
mlir::computeStrides(shape)) {}
34 assert(t.hasStaticShape());
35 return ArrayIndexGen(t);
/// Reports whether `idx` addresses an element along a dimension whose extent
/// is `dimSize`, i.e. whether 0 <= idx < dimSize (a half-open interval).
/// A non-positive `dimSize` therefore admits no index at all.
inline bool isInRange(int64_t idx, int64_t dimSize) {
  // Reject negative positions first, then check the exclusive upper bound.
  if (idx < 0)
    return false;
  return idx < dimSize;
}
43template <
typename TypeOfIndex>
inline std::optional<int64_t> toI64(TypeOfIndex index) {
45 if (!mlir::matchPattern(index, mlir::m_ConstantInt(&idxAP))) {
51template <
typename OutType>
struct CheckAndConvert {
52 template <
typename InType>
static std::optional<OutType>
from(InType index, int64_t dimSize) {
53 static_assert(
sizeof(OutType) == 0,
"CheckAndConvert not implemented for requested type.");
59template <>
struct CheckAndConvert<int64_t> {
60 template <
typename InType>
static std::optional<int64_t>
from(InType index, int64_t dimSize) {
61 if (
auto idxVal = toI64<InType>(index)) {
62 if (isInRange(*idxVal, dimSize)) {
71template <>
struct CheckAndConvert<Attribute> {
72 template <
typename InType>
static std::optional<Attribute>
from(InType index, int64_t dimSize) {
73 if (
auto c = CheckAndConvert<int64_t>::from(index, dimSize)) {
74 return IntegerAttr::get(IndexType::get(index.getContext()), *c);
80template <
typename OutType,
typename InListType>
81inline std::optional<SmallVector<OutType>>
82checkAndConvertMulti(InListType multiDimIndex, ArrayRef<int64_t> shape,
bool mustBeEqual) {
85 llvm::all_equal({llvm::range_size(multiDimIndex), llvm::range_size(shape)}) &&
86 "Iteratees do not have equal length"
89 SmallVector<OutType> ret;
90 for (
auto [idx, dimSize] : llvm::zip_first(multiDimIndex, shape)) {
91 std::optional<OutType> next = CheckAndConvert<OutType>::from(idx, dimSize);
92 if (!next.has_value()) {
95 ret.push_back(next.value());
100inline std::optional<int64_t> linearizeImpl(
101 ArrayRef<int64_t> multiDimIndex,
const ArrayRef<int64_t> &shape,
102 const SmallVector<int64_t> &strides
105 for (
auto [idx, dimSize] : llvm::zip_equal(multiDimIndex, shape)) {
106 if (!isInRange(idx, dimSize)) {
110 return mlir::linearize(multiDimIndex, strides);
113template <
typename TypeOfIndex>
114inline std::optional<int64_t> linearizeImpl(
115 ArrayRef<TypeOfIndex> multiDimIndex,
const ArrayRef<int64_t> &shape,
116 const SmallVector<int64_t> &strides
118 std::optional<SmallVector<int64_t>> conv =
119 checkAndConvertMulti<int64_t>(multiDimIndex, shape,
true );
120 if (!conv.has_value()) {
123 return mlir::linearize(conv.value(), strides);
126template <
typename ResultElemType>
127inline std::optional<SmallVector<ResultElemType>> delinearizeImpl(
128 int64_t linearIndex, int64_t linearSize,
const SmallVector<int64_t> &strides, MLIRContext *ctx,
129 llvm::function_ref<ResultElemType(IntegerAttr)> convert
131 if (!isInRange(linearIndex, linearSize)) {
134 SmallVector<ResultElemType> ret;
135 for (int64_t idx : mlir::delinearize(linearIndex, strides)) {
136 ret.push_back(convert(IntegerAttr::get(IndexType::get(ctx), idx)));
143std::optional<SmallVector<Value>>
145 return delinearizeImpl<Value>(
146 linearIndex, linearSize, strides, bldr.getContext(),
147 [&](IntegerAttr a) { return bldr.create<arith::ConstantOp>(loc, a); }
151std::optional<SmallVector<Attribute>>
153 return delinearizeImpl<Attribute>(linearIndex, linearSize, strides, ctx, [](IntegerAttr a) {
159 static_assert(
sizeof(InListType) == 0,
"linearize() not implemented for requested type.");
164 return linearizeImpl(multiDimIndex, shape, strides);
169 return linearizeImpl(multiDimIndex, shape, strides);
174 return linearizeImpl(multiDimIndex, shape, strides);
178 return linearizeImpl(multiDimIndex, shape, strides);
181template <
typename InListType>
183 static_assert(
sizeof(InListType) == 0,
"checkAndConvert() not implemented for requested type.");
189 return checkAndConvertMulti<Attribute>(multiDimIndex, shape,
false);
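Apache License, Version 2.0, January 2004, http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.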
std::optional< llvm::SmallVector< mlir::Attribute > > checkAndConvert(InListType multiDimIndex)
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< int64_t > linearize(InListType multiDimIndex) const
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
int64_t fromAPInt(llvm::APInt i)