64#include <mlir/IR/BuiltinOps.h>
65#include <mlir/Pass/PassManager.h>
66#include <mlir/Transforms/DialectConversion.h>
67#include <mlir/Transforms/Passes.h>
69#include <llvm/Support/Debug.h>
73#define GEN_PASS_DEF_ARRAYTOSCALARPASS
83#define DEBUG_TYPE "llzk-array-to-scalar"
88inline ArrayType splittableArray(
ArrayType at) {
return at.hasStaticShape() ? at :
nullptr; }
92 if (
ArrayType at = dyn_cast<ArrayType>(t)) {
93 return splittableArray(at);
100inline bool containsSplittableArrayType(Type t) {
103 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
108template <
typename T>
bool containsSplittableArrayType(ValueTypeRange<T> types) {
109 for (Type t : types) {
110 if (containsSplittableArrayType(t)) {
119void splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
121 int64_t n = at.getNumElements();
122 assert(std::cmp_less_equal(n, std::numeric_limits<SmallVector<Type>::size_type>::max()));
125 collect.push_back(t);
130template <
typename TypeCollection>
131inline void splitArrayTypeTo(TypeCollection types, SmallVector<Type> &collect) {
132 for (Type t : types) {
133 splitArrayTypeTo(t, collect);
139template <
typename TypeCollection>
inline SmallVector<Type> splitArrayType(TypeCollection types) {
140 SmallVector<Type> collect;
141 splitArrayTypeTo(types, collect);
148genIndexConstants(ArrayAttr index, Location loc, ConversionPatternRewriter &rewriter) {
149 SmallVector<Value> operands;
150 for (Attribute a : index) {
152 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
153 assert(ia && ia.getType().isIndex());
154 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
160 Location loc, Value baseArrayOp, ArrayAttr index, Value init,
161 ConversionPatternRewriter &rewriter
163 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
164 return rewriter.create<
WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
170CallOp newCallOpWithSplitResults(
173 OpBuilder::InsertionGuard guard(rewriter);
174 rewriter.setInsertionPointAfter(oldCall);
176 Operation::result_range oldResults = oldCall.getResults();
178 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.
getCallee(),
182 auto newResults = newCall.getResults().begin();
183 for (Value oldVal : oldResults) {
184 if (
ArrayType at = splittableArray(oldVal.getType())) {
185 Location loc = oldVal.getLoc();
188 rewriter.replaceAllUsesWith(oldVal, newArray);
194 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
195 for (ArrayAttr subIdx : allIndices.value()) {
196 genWrite(loc, newArray, subIdx, *newResults, rewriter);
204 rewriter.eraseOp(oldCall);
213void processBlockArgs(Block &entryBlock, ConversionPatternRewriter &rewriter) {
214 OpBuilder::InsertionGuard guard(rewriter);
215 rewriter.setInsertionPointToStart(&entryBlock);
217 for (
unsigned i = 0; i < entryBlock.getNumArguments();) {
218 Value oldV = entryBlock.getArgument(i);
219 if (
ArrayType at = splittableArray(oldV.getType())) {
220 Location loc = oldV.getLoc();
223 rewriter.replaceAllUsesWith(oldV, newArray);
225 entryBlock.eraseArgument(i);
230 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
231 for (ArrayAttr subIdx : allIndices.value()) {
232 BlockArgument newArg = entryBlock.insertArgument(i, at.
getElementType(), loc);
233 genWrite(loc, newArray, subIdx, newArg, rewriter);
243genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
244 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
245 return rewriter.create<
ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
250void processInputOperand(
251 Location loc, Value operand, SmallVector<Value> &newOperands,
252 ConversionPatternRewriter &rewriter
254 if (
ArrayType at = splittableArray(operand.getType())) {
256 assert(indices.has_value() &&
"passed earlier hasStaticShape() check");
257 for (ArrayAttr index : indices.value()) {
258 newOperands.push_back(genRead(loc, operand, index, rewriter));
261 newOperands.push_back(operand);
266void processInputOperands(
267 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
268 ConversionPatternRewriter &rewriter
270 SmallVector<Value> newOperands;
271 for (Value v : operands) {
272 processInputOperand(op->getLoc(), v, newOperands, rewriter);
274 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
275 outputOpRef.assign(ValueRange(newOperands));
290template <Direction dir>
291inline void rewriteImpl(
293 ConversionPatternRewriter &rewriter
296 Location loc = op.getLoc();
297 MLIRContext *ctx = op.getContext();
306 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
307 for (ArrayAttr indexingTail : subIndices.value()) {
308 SmallVector<Attribute> joined;
309 joined.append(indexAsAttr.begin(), indexAsAttr.end());
310 joined.append(indexingTail.begin(), indexingTail.end());
311 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
313 if constexpr (dir == Direction::SMALL_TO_LARGE) {
314 auto init = genRead(loc, smallArr, indexingTail, rewriter);
315 genWrite(loc, largeArr, fullIndex, init, rewriter);
316 }
else if constexpr (dir == Direction::LARGE_TO_SMALL) {
317 auto init = genRead(loc, largeArr, fullIndex, rewriter);
318 genWrite(loc, smallArr, indexingTail, init, rewriter);
325class SplitInsertArrayOp :
public OpConversionPattern<InsertArrayOp> {
327 using OpConversionPattern<
InsertArrayOp>::OpConversionPattern;
330 return !containsSplittableArrayType(op.
getRvalue().getType());
333 LogicalResult match(
InsertArrayOp op)
const override {
return failure(legal(op)); }
336 rewrite(
InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
338 rewriteImpl<SMALL_TO_LARGE>(
339 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
340 adaptor.getArrRef(), rewriter
342 rewriter.eraseOp(op);
346class SplitExtractArrayOp :
public OpConversionPattern<ExtractArrayOp> {
351 return !containsSplittableArrayType(op.
getResult().getType());
354 LogicalResult match(
ExtractArrayOp op)
const override {
return failure(legal(op)); }
356 void rewrite(
ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
360 auto newArray = rewriter.replaceOpWithNewOp<
CreateArrayOp>(op, at);
361 rewriteImpl<LARGE_TO_SMALL>(
362 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
368class SplitInitFromCreateArrayOp :
public OpConversionPattern<CreateArrayOp> {
370 using OpConversionPattern<
CreateArrayOp>::OpConversionPattern;
374 LogicalResult match(
CreateArrayOp op)
const override {
return failure(legal(op)); }
377 rewrite(
CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
381 rewriter.setInsertionPointAfter(op);
382 Location loc = op.getLoc();
384 for (
auto [i, init] : llvm::enumerate(adaptor.getElements())) {
386 assert(std::cmp_less_equal(i, std::numeric_limits<int64_t>::max()));
387 std::optional<SmallVector<Value>> multiDimIdxVals =
388 idxGen.
delinearize(
static_cast<int64_t
>(i), loc, rewriter);
391 assert(multiDimIdxVals.has_value());
398class SplitArrayInFuncDefOp :
public OpConversionPattern<FuncDefOp> {
400 using OpConversionPattern<
FuncDefOp>::OpConversionPattern;
406 LogicalResult match(
FuncDefOp op)
const override {
return failure(legal(op)); }
408 void rewrite(
FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
411 SmallVector<Type> newInputs = splitArrayType(oldTy.getInputs());
412 SmallVector<Type> newOutputs = splitArrayType(oldTy.getResults());
414 FunctionType::get(oldTy.getContext(), TypeRange(newInputs), TypeRange(newOutputs));
415 if (newTy == oldTy) {
418 rewriter.modifyOpInPlace(op, [&op, &newTy]() { op.
setFunctionType(newTy); });
422 Block &entryBlock = body->front();
423 if (std::cmp_equal(entryBlock.getNumArguments(), newInputs.size())) {
426 processBlockArgs(entryBlock, rewriter);
431class SplitArrayInReturnOp :
public OpConversionPattern<ReturnOp> {
433 using OpConversionPattern<
ReturnOp>::OpConversionPattern;
435 inline static bool legal(
ReturnOp op) {
436 return !containsSplittableArrayType(op.
getOperands().getTypes());
439 LogicalResult match(
ReturnOp op)
const override {
return failure(legal(op)); }
441 void rewrite(
ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
446class SplitArrayInCallOp :
public OpConversionPattern<CallOp> {
448 using OpConversionPattern<
CallOp>::OpConversionPattern;
450 inline static bool legal(
CallOp op) {
451 return !containsSplittableArrayType(op.
getArgOperands().getTypes()) &&
452 !containsSplittableArrayType(op.getResultTypes());
455 LogicalResult match(
CallOp op)
const override {
return failure(legal(op)); }
457 void rewrite(
CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
461 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
462 processInputOperands(
468class ReplaceKnownArrayLengthOp :
public OpConversionPattern<ArrayLengthOp> {
470 using OpConversionPattern<
ArrayLengthOp>::OpConversionPattern;
473 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx,
ArrayType baseArrType) {
474 if (splittableArray(baseArrType)) {
476 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
477 uint64_t idx64 = idxAP.getZExtValue();
478 assert(std::cmp_less_equal(idx64, std::numeric_limits<size_t>::max()));
479 Attribute dimSizeAttr = baseArrType.
getDimensionSizes()[
static_cast<size_t>(idx64)];
480 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
493 LogicalResult match(
ArrayLengthOp op)
const override {
return failure(legal(op)); }
496 rewrite(
ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
497 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
499 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
500 assert(len.has_value());
501 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op,
llzk::fromAPInt(len.value()));
/// Describes one scalar field that replaces an element of a split array field:
/// `first` is the new field's symbol name (as returned by SymbolTable::insert),
/// `second` is its type (the element type of the original array).
506using FieldInfo = std::pair<StringAttr, Type>;
/// For a single original array field: maps each element index (an ArrayAttr of
/// per-dimension indices) to the scalar field that now stores that element.
508using LocalFieldReplacementMap = DenseMap<ArrayAttr, FieldInfo>;
/// Per struct, maps the name of each original (now removed) array field to its
/// per-element replacement map. Filled while splitting FieldDefOps and read when
/// rewriting FieldReadOp/FieldWriteOp users.
510using FieldReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalFieldReplacementMap>>;
512class SplitArrayInFieldDefOp :
public OpConversionPattern<FieldDefOp> {
513 SymbolTableCollection &tables;
514 FieldReplacementMap &repMapRef;
517 SplitArrayInFieldDefOp(
518 MLIRContext *ctx, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap
520 : OpConversionPattern<FieldDefOp>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
522 inline static bool legal(
FieldDefOp op) {
return !containsSplittableArrayType(op.
getType()); }
524 LogicalResult match(
FieldDefOp op)
const override {
return failure(legal(op)); }
526 void rewrite(
FieldDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
529 LocalFieldReplacementMap &localRepMapRef = repMapRef[inStruct][op.
getSymNameAttr()];
534 assert(subIdxs.has_value());
537 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
538 for (ArrayAttr idx : subIdxs.value()) {
543 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newField), elemTy);
545 rewriter.eraseOp(op);
556class SplitArrayInFieldRefOp :
public OpConversionPattern<FieldRefOpClass> {
557 SymbolTableCollection &tables;
558 const FieldReplacementMap &repMapRef;
561 inline static void ensureImplementedAtCompile() {
563 sizeof(FieldRefOpClass) == 0,
"SplitArrayInFieldRefOp not implemented for requested type."
568 using OpAdaptor =
typename FieldRefOpClass::Adaptor;
572 static GenHeaderType genHeader(FieldRefOpClass, ConversionPatternRewriter &) {
573 ensureImplementedAtCompile();
574 assert(
false &&
"unreachable");
580 forIndex(Location, GenHeaderType, ArrayAttr, FieldInfo, OpAdaptor, ConversionPatternRewriter &) {
581 ensureImplementedAtCompile();
582 assert(
false &&
"unreachable");
586 SplitArrayInFieldRefOp(
587 MLIRContext *ctx, SymbolTableCollection &symTables,
const FieldReplacementMap &fieldRepMap
589 : OpConversionPattern<FieldRefOpClass>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
591 static bool legal(FieldRefOpClass) {
592 ensureImplementedAtCompile();
593 assert(
false &&
"unreachable");
596 LogicalResult match(FieldRefOpClass op)
const override {
return failure(ImplClass::legal(op)); }
598 void rewrite(FieldRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
600 StructType tgtStructTy = llvm::cast<FieldRefOpInterface>(op.getOperation()).getStructType();
603 assert(succeeded(tgtStructDef));
605 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
607 const LocalFieldReplacementMap &idxToName =
608 repMapRef.at(tgtStructDef->get()).at(op.getFieldNameAttr().getAttr());
610 for (
auto [idx, newField] : idxToName) {
611 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newField, adaptor, rewriter);
613 rewriter.eraseOp(op);
617class SplitArrayInFieldWriteOp
618 :
public SplitArrayInFieldRefOp<SplitArrayInFieldWriteOp, FieldWriteOp, void *> {
620 using SplitArrayInFieldRefOp<
621 SplitArrayInFieldWriteOp,
FieldWriteOp,
void *>::SplitArrayInFieldRefOp;
623 static bool legal(
FieldWriteOp op) {
return !containsSplittableArrayType(op.
getVal().getType()); }
625 static void *genHeader(
FieldWriteOp, ConversionPatternRewriter &) {
return nullptr; }
627 static void forIndex(
628 Location loc,
void *, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
629 ConversionPatternRewriter &rewriter
631 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
633 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newField.first), scalarRead
638class SplitArrayInFieldReadOp
639 :
public SplitArrayInFieldRefOp<SplitArrayInFieldReadOp, FieldReadOp, CreateArrayOp> {
641 using SplitArrayInFieldRefOp<
645 return !containsSplittableArrayType(op.getResult().getType());
650 rewriter.create<
CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
651 rewriter.replaceAllUsesWith(op, newArray);
655 static void forIndex(
656 Location loc,
CreateArrayOp newArray, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
657 ConversionPatternRewriter &rewriter
660 rewriter.create<
FieldReadOp>(loc, newField.second, adaptor.getComponent(), newField.first);
661 genWrite(loc, newArray, idx, scalarRead, rewriter);
666step1(ModuleOp modOp, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap) {
667 MLIRContext *ctx = modOp.getContext();
669 RewritePatternSet patterns(ctx);
671 patterns.add<SplitArrayInFieldDefOp>(ctx, symTables, fieldRepMap);
673 ConversionTarget target(*ctx);
674 target.addLegalDialect<
679 target.addLegalOp<ModuleOp>();
680 target.addDynamicallyLegalOp<
FieldDefOp>(SplitArrayInFieldDefOp::legal);
682 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 1: split array fields\n";);
683 return applyFullConversion(modOp, target, std::move(patterns));
687step2(ModuleOp modOp, SymbolTableCollection &symTables,
const FieldReplacementMap &fieldRepMap) {
688 MLIRContext *ctx = modOp.getContext();
690 RewritePatternSet patterns(ctx);
693 SplitInitFromCreateArrayOp,
696 SplitArrayInFuncDefOp,
697 SplitArrayInReturnOp,
699 ReplaceKnownArrayLengthOp
705 SplitArrayInFieldWriteOp,
706 SplitArrayInFieldReadOp
708 >(ctx, symTables, fieldRepMap);
710 ConversionTarget target(*ctx);
711 target.addLegalDialect<
716 target.addLegalOp<ModuleOp>();
717 target.addDynamicallyLegalOp<
CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
718 target.addDynamicallyLegalOp<
InsertArrayOp>(SplitInsertArrayOp::legal);
719 target.addDynamicallyLegalOp<
ExtractArrayOp>(SplitExtractArrayOp::legal);
720 target.addDynamicallyLegalOp<
FuncDefOp>(SplitArrayInFuncDefOp::legal);
721 target.addDynamicallyLegalOp<
ReturnOp>(SplitArrayInReturnOp::legal);
722 target.addDynamicallyLegalOp<
CallOp>(SplitArrayInCallOp::legal);
723 target.addDynamicallyLegalOp<
ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
724 target.addDynamicallyLegalOp<
FieldWriteOp>(SplitArrayInFieldWriteOp::legal);
725 target.addDynamicallyLegalOp<
FieldReadOp>(SplitArrayInFieldReadOp::legal);
727 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 2: update/split other array ops\n";);
728 return applyFullConversion(modOp, target, std::move(patterns));
731LogicalResult splitArrayCreateInit(ModuleOp modOp) {
732 SymbolTableCollection symTables;
733 FieldReplacementMap fieldRepMap;
739 if (failed(step1(modOp, symTables, fieldRepMap))) {
742 return step2(modOp, symTables, fieldRepMap);
746 void runOnOperation()
override {
747 ModuleOp module = getOperation();
750 if (failed(splitArrayCreateInit(module))) {
754 OpPassManager nestedPM(ModuleOp::getOperationName());
758 nestedPM.addPass(createSROA());
760 nestedPM.addPass(createMem2Reg());
761 if (failed(runPipeline(nestedPM, module))) {
771 return std::make_unique<ArrayToScalarPass>();
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::TypedValue<::mlir::IndexType > getDim()
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::MutableOperandRange getElementsMutable()
::mlir::Operation::operand_range getElements()
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
::mlir::StringAttr getSymNameAttr()
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
::mlir::SymbolRefAttr getCallee()
::mlir::MutableOperandRange getArgOperandsMutable()
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
::mlir::Operation::operand_range getArgOperands()
::mlir::FunctionType getFunctionType()
void setFunctionType(::mlir::FunctionType attrValue)
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
::mlir::Operation::operand_range getOperands()
::mlir::MutableOperandRange getOperandsMutable()
Restricts a template parameter to Op classes that implement the given OpInterface.
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
int64_t fromAPInt(llvm::APInt i)