65#include <mlir/IR/BuiltinOps.h>
66#include <mlir/Pass/PassManager.h>
67#include <mlir/Transforms/DialectConversion.h>
68#include <mlir/Transforms/Passes.h>
70#include <llvm/Support/Debug.h>
74#define GEN_PASS_DEF_ARRAYTOSCALARPASS
84#define DEBUG_TYPE "llzk-array-to-scalar"
89inline ArrayType splittableArray(
ArrayType at) {
return at.hasStaticShape() ? at :
nullptr; }
93 if (
ArrayType at = dyn_cast<ArrayType>(t)) {
94 return splittableArray(at);
101inline bool containsSplittableArrayType(Type t) {
104 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
109template <
typename T>
bool containsSplittableArrayType(ValueTypeRange<T> types) {
110 for (Type t : types) {
111 if (containsSplittableArrayType(t)) {
120size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
122 int64_t n = at.getNumElements();
124 assert(std::cmp_less_equal(n, std::numeric_limits<size_t>::max()));
129 collect.push_back(t);
135template <
typename TypeCollection>
136inline void splitArrayTypeTo(
137 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
139 for (Type t : types) {
140 size_t count = splitArrayTypeTo(t, collect);
141 if (originalIdxToSize) {
142 originalIdxToSize->push_back(count);
149template <
typename TypeCollection>
150inline SmallVector<Type>
151splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize =
nullptr) {
152 SmallVector<Type> collect;
153 splitArrayTypeTo(types, collect, originalIdxToSize);
159SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
160 SmallVector<Value> operands;
161 for (Attribute a : index) {
163 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
164 assert(ia && ia.getType().isIndex());
165 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
171genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
172 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
173 return rewriter.create<
WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
179CallOp newCallOpWithSplitResults(
182 OpBuilder::InsertionGuard guard(rewriter);
183 rewriter.setInsertionPointAfter(oldCall);
185 Operation::result_range oldResults = oldCall.getResults();
187 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.
getCallee(),
191 auto newResults = newCall.getResults().begin();
192 for (Value oldVal : oldResults) {
193 if (
ArrayType at = splittableArray(oldVal.getType())) {
194 Location loc = oldVal.getLoc();
197 rewriter.replaceAllUsesWith(oldVal, newArray);
203 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
204 for (ArrayAttr subIdx : allIndices.value()) {
205 genWrite(loc, newArray, subIdx, *newResults, rewriter);
213 rewriter.eraseOp(oldCall);
219genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
220 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
221 return rewriter.create<
ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
226void processInputOperand(
227 Location loc, Value operand, SmallVector<Value> &newOperands,
228 ConversionPatternRewriter &rewriter
230 if (
ArrayType at = splittableArray(operand.getType())) {
232 assert(indices.has_value() &&
"passed earlier hasStaticShape() check");
233 for (ArrayAttr index : indices.value()) {
234 newOperands.push_back(genRead(loc, operand, index, rewriter));
237 newOperands.push_back(operand);
242void processInputOperands(
243 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
244 ConversionPatternRewriter &rewriter
246 SmallVector<Value> newOperands;
247 for (Value v : operands) {
248 processInputOperand(op->getLoc(), v, newOperands, rewriter);
250 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
251 outputOpRef.assign(ValueRange(newOperands));
266template <Direction dir>
267inline void rewriteImpl(
269 ConversionPatternRewriter &rewriter
272 Location loc = op.getLoc();
273 MLIRContext *ctx = op.getContext();
282 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
283 for (ArrayAttr indexingTail : subIndices.value()) {
284 SmallVector<Attribute> joined;
285 joined.append(indexAsAttr.begin(), indexAsAttr.end());
286 joined.append(indexingTail.begin(), indexingTail.end());
287 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
289 if constexpr (dir == Direction::SMALL_TO_LARGE) {
290 auto init = genRead(loc, smallArr, indexingTail, rewriter);
291 genWrite(loc, largeArr, fullIndex, init, rewriter);
292 }
else if constexpr (dir == Direction::LARGE_TO_SMALL) {
293 auto init = genRead(loc, largeArr, fullIndex, rewriter);
294 genWrite(loc, smallArr, indexingTail, init, rewriter);
301class SplitInsertArrayOp :
public OpConversionPattern<InsertArrayOp> {
303 using OpConversionPattern<
InsertArrayOp>::OpConversionPattern;
306 return !containsSplittableArrayType(op.
getRvalue().getType());
309 LogicalResult match(
InsertArrayOp op)
const override {
return failure(legal(op)); }
312 rewrite(
InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
314 rewriteImpl<SMALL_TO_LARGE>(
315 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
316 adaptor.getArrRef(), rewriter
318 rewriter.eraseOp(op);
322class SplitExtractArrayOp :
public OpConversionPattern<ExtractArrayOp> {
327 return !containsSplittableArrayType(op.
getResult().getType());
330 LogicalResult match(
ExtractArrayOp op)
const override {
return failure(legal(op)); }
332 void rewrite(
ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
336 auto newArray = rewriter.replaceOpWithNewOp<
CreateArrayOp>(op, at);
337 rewriteImpl<LARGE_TO_SMALL>(
338 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
344class SplitInitFromCreateArrayOp :
public OpConversionPattern<CreateArrayOp> {
346 using OpConversionPattern<
CreateArrayOp>::OpConversionPattern;
350 LogicalResult match(
CreateArrayOp op)
const override {
return failure(legal(op)); }
353 rewrite(
CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
357 rewriter.setInsertionPointAfter(op);
358 Location loc = op.getLoc();
360 for (
auto [i, init] : llvm::enumerate(adaptor.getElements())) {
362 assert(std::cmp_less_equal(i, std::numeric_limits<int64_t>::max()));
363 std::optional<SmallVector<Value>> multiDimIdxVals =
364 idxGen.
delinearize(
static_cast<int64_t
>(i), loc, rewriter);
367 assert(multiDimIdxVals.has_value());
374class SplitArrayInFuncDefOp :
public OpConversionPattern<FuncDefOp> {
376 using OpConversionPattern<
FuncDefOp>::OpConversionPattern;
385 static ArrayAttr replicateAttributesAsNeeded(
386 ArrayAttr origAttrs,
const SmallVector<size_t> &originalIdxToSize,
387 const SmallVector<Type> &newTypes
390 assert(originalIdxToSize.size() == origAttrs.size());
391 if (originalIdxToSize.size() != newTypes.size()) {
392 SmallVector<Attribute> newArgAttrs;
393 for (
auto [i, s] : llvm::enumerate(originalIdxToSize)) {
394 newArgAttrs.append(s, origAttrs[i]);
396 return ArrayAttr::get(origAttrs.getContext(), newArgAttrs);
402 LogicalResult match(
FuncDefOp op)
const override {
return failure(legal(op)); }
404 void rewrite(
FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
407 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
410 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes)
override {
411 return splitArrayType(origTypes, &originalInputIdxToSize);
413 SmallVector<Type> convertResults(ArrayRef<Type> origTypes)
override {
414 return splitArrayType(origTypes, &originalResultIdxToSize);
416 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
417 return replicateAttributesAsNeeded(origAttrs, originalInputIdxToSize, newTypes);
419 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
420 return replicateAttributesAsNeeded(origAttrs, originalResultIdxToSize, newTypes);
427 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter)
override {
428 OpBuilder::InsertionGuard guard(rewriter);
429 rewriter.setInsertionPointToStart(&entryBlock);
431 for (
unsigned i = 0; i < entryBlock.getNumArguments();) {
432 Value oldV = entryBlock.getArgument(i);
433 if (
ArrayType at = splittableArray(oldV.getType())) {
434 Location loc = oldV.getLoc();
437 rewriter.replaceAllUsesWith(oldV, newArray);
439 entryBlock.eraseArgument(i);
444 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
445 for (ArrayAttr subIdx : allIndices.value()) {
446 BlockArgument newArg = entryBlock.insertArgument(i, at.
getElementType(), loc);
447 genWrite(loc, newArray, subIdx, newArg, rewriter);
456 Impl().convert(op, rewriter);
460class SplitArrayInReturnOp :
public OpConversionPattern<ReturnOp> {
462 using OpConversionPattern<
ReturnOp>::OpConversionPattern;
464 inline static bool legal(
ReturnOp op) {
465 return !containsSplittableArrayType(op.
getOperands().getTypes());
468 LogicalResult match(
ReturnOp op)
const override {
return failure(legal(op)); }
470 void rewrite(
ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
475class SplitArrayInCallOp :
public OpConversionPattern<CallOp> {
477 using OpConversionPattern<
CallOp>::OpConversionPattern;
479 inline static bool legal(
CallOp op) {
480 return !containsSplittableArrayType(op.
getArgOperands().getTypes()) &&
481 !containsSplittableArrayType(op.getResultTypes());
484 LogicalResult match(
CallOp op)
const override {
return failure(legal(op)); }
486 void rewrite(
CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
490 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
491 processInputOperands(
497class ReplaceKnownArrayLengthOp :
public OpConversionPattern<ArrayLengthOp> {
499 using OpConversionPattern<
ArrayLengthOp>::OpConversionPattern;
502 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx,
ArrayType baseArrType) {
503 if (splittableArray(baseArrType)) {
505 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
506 uint64_t idx64 = idxAP.getZExtValue();
507 assert(std::cmp_less_equal(idx64, std::numeric_limits<size_t>::max()));
508 Attribute dimSizeAttr = baseArrType.
getDimensionSizes()[
static_cast<size_t>(idx64)];
509 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
522 LogicalResult match(
ArrayLengthOp op)
const override {
return failure(legal(op)); }
525 rewrite(
ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
526 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
528 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
529 assert(len.has_value());
530 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op,
llzk::fromAPInt(len.value()));
// Name and type of one replacement field: the StringAttr returned when the new
// field is inserted into the struct's symbol table, paired with its type.
535using FieldInfo = std::pair<StringAttr, Type>;
// Maps an index into the original array-typed field (as an ArrayAttr) to the
// replacement field that stands in for that element.
537using LocalFieldReplacementMap = DenseMap<ArrayAttr, FieldInfo>;
// Per-struct lookup: struct definition -> original field name -> per-index
// replacement fields.
539using FieldReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalFieldReplacementMap>>;
541class SplitArrayInFieldDefOp :
public OpConversionPattern<FieldDefOp> {
542 SymbolTableCollection &tables;
543 FieldReplacementMap &repMapRef;
546 SplitArrayInFieldDefOp(
547 MLIRContext *ctx, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap
549 : OpConversionPattern<FieldDefOp>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
551 inline static bool legal(
FieldDefOp op) {
return !containsSplittableArrayType(op.
getType()); }
553 LogicalResult match(
FieldDefOp op)
const override {
return failure(legal(op)); }
555 void rewrite(
FieldDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
558 LocalFieldReplacementMap &localRepMapRef = repMapRef[inStruct][op.
getSymNameAttr()];
563 assert(subIdxs.has_value());
566 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
567 for (ArrayAttr idx : subIdxs.value()) {
573 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newField), elemTy);
575 rewriter.eraseOp(op);
586class SplitArrayInFieldRefOp :
public OpConversionPattern<FieldRefOpClass> {
587 SymbolTableCollection &tables;
588 const FieldReplacementMap &repMapRef;
591 inline static void ensureImplementedAtCompile() {
593 sizeof(FieldRefOpClass) == 0,
"SplitArrayInFieldRefOp not implemented for requested type."
598 using OpAdaptor =
typename FieldRefOpClass::Adaptor;
602 static GenHeaderType genHeader(FieldRefOpClass, ConversionPatternRewriter &) {
603 ensureImplementedAtCompile();
604 assert(
false &&
"unreachable");
610 forIndex(Location, GenHeaderType, ArrayAttr, FieldInfo, OpAdaptor, ConversionPatternRewriter &) {
611 ensureImplementedAtCompile();
612 assert(
false &&
"unreachable");
616 SplitArrayInFieldRefOp(
617 MLIRContext *ctx, SymbolTableCollection &symTables,
const FieldReplacementMap &fieldRepMap
619 : OpConversionPattern<FieldRefOpClass>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
621 static bool legal(FieldRefOpClass) {
622 ensureImplementedAtCompile();
623 assert(
false &&
"unreachable");
626 LogicalResult match(FieldRefOpClass op)
const override {
return failure(ImplClass::legal(op)); }
628 void rewrite(FieldRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
630 StructType tgtStructTy = llvm::cast<FieldRefOpInterface>(op.getOperation()).getStructType();
633 assert(succeeded(tgtStructDef));
635 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
637 const LocalFieldReplacementMap &idxToName =
638 repMapRef.at(tgtStructDef->get()).at(op.getFieldNameAttr().getAttr());
640 for (
auto [idx, newField] : idxToName) {
641 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newField, adaptor, rewriter);
643 rewriter.eraseOp(op);
647class SplitArrayInFieldWriteOp
648 :
public SplitArrayInFieldRefOp<SplitArrayInFieldWriteOp, FieldWriteOp, void *> {
650 using SplitArrayInFieldRefOp<
651 SplitArrayInFieldWriteOp,
FieldWriteOp,
void *>::SplitArrayInFieldRefOp;
653 static bool legal(
FieldWriteOp op) {
return !containsSplittableArrayType(op.
getVal().getType()); }
655 static void *genHeader(
FieldWriteOp, ConversionPatternRewriter &) {
return nullptr; }
657 static void forIndex(
658 Location loc,
void *, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
659 ConversionPatternRewriter &rewriter
661 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
663 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newField.first), scalarRead
668class SplitArrayInFieldReadOp
669 :
public SplitArrayInFieldRefOp<SplitArrayInFieldReadOp, FieldReadOp, CreateArrayOp> {
671 using SplitArrayInFieldRefOp<
675 return !containsSplittableArrayType(op.getResult().getType());
680 rewriter.create<
CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
681 rewriter.replaceAllUsesWith(op, newArray);
685 static void forIndex(
686 Location loc,
CreateArrayOp newArray, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
687 ConversionPatternRewriter &rewriter
690 rewriter.create<
FieldReadOp>(loc, newField.second, adaptor.getComponent(), newField.first);
691 genWrite(loc, newArray, idx, scalarRead, rewriter);
696step1(ModuleOp modOp, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap) {
697 MLIRContext *ctx = modOp.getContext();
699 RewritePatternSet patterns(ctx);
701 patterns.add<SplitArrayInFieldDefOp>(ctx, symTables, fieldRepMap);
703 ConversionTarget target(*ctx);
704 target.addLegalDialect<
709 target.addLegalOp<ModuleOp>();
710 target.addDynamicallyLegalOp<
FieldDefOp>(SplitArrayInFieldDefOp::legal);
712 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 1: split array fields\n";);
713 return applyFullConversion(modOp, target, std::move(patterns));
717step2(ModuleOp modOp, SymbolTableCollection &symTables,
const FieldReplacementMap &fieldRepMap) {
718 MLIRContext *ctx = modOp.getContext();
720 RewritePatternSet patterns(ctx);
723 SplitInitFromCreateArrayOp,
726 SplitArrayInFuncDefOp,
727 SplitArrayInReturnOp,
729 ReplaceKnownArrayLengthOp
735 SplitArrayInFieldWriteOp,
736 SplitArrayInFieldReadOp
738 >(ctx, symTables, fieldRepMap);
740 ConversionTarget target(*ctx);
741 target.addLegalDialect<
746 target.addLegalOp<ModuleOp>();
747 target.addDynamicallyLegalOp<
CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
748 target.addDynamicallyLegalOp<
InsertArrayOp>(SplitInsertArrayOp::legal);
749 target.addDynamicallyLegalOp<
ExtractArrayOp>(SplitExtractArrayOp::legal);
750 target.addDynamicallyLegalOp<
FuncDefOp>(SplitArrayInFuncDefOp::legal);
751 target.addDynamicallyLegalOp<
ReturnOp>(SplitArrayInReturnOp::legal);
752 target.addDynamicallyLegalOp<
CallOp>(SplitArrayInCallOp::legal);
753 target.addDynamicallyLegalOp<
ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
754 target.addDynamicallyLegalOp<
FieldWriteOp>(SplitArrayInFieldWriteOp::legal);
755 target.addDynamicallyLegalOp<
FieldReadOp>(SplitArrayInFieldReadOp::legal);
757 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 2: update/split other array ops\n";);
758 return applyFullConversion(modOp, target, std::move(patterns));
761LogicalResult splitArrayCreateInit(ModuleOp modOp) {
762 SymbolTableCollection symTables;
763 FieldReplacementMap fieldRepMap;
769 if (failed(step1(modOp, symTables, fieldRepMap))) {
772 return step2(modOp, symTables, fieldRepMap);
776 void runOnOperation()
override {
777 ModuleOp module = getOperation();
780 if (failed(splitArrayCreateInit(module))) {
784 OpPassManager nestedPM(ModuleOp::getOperationName());
788 nestedPM.addPass(createSROA());
790 nestedPM.addPass(createMem2Reg());
791 if (failed(runPipeline(nestedPM, module))) {
801 return std::make_unique<ArrayToScalarPass>();
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::TypedValue<::mlir::IndexType > getDim()
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::MutableOperandRange getElementsMutable()
::mlir::Operation::operand_range getElements()
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
void setPublicAttr(bool newValue=true)
::mlir::StringAttr getSymNameAttr()
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
::mlir::SymbolRefAttr getCallee()
::mlir::MutableOperandRange getArgOperandsMutable()
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
::mlir::Operation::operand_range getArgOperands()
::mlir::FunctionType getFunctionType()
::mlir::Operation::operand_range getOperands()
::mlir::MutableOperandRange getOperandsMutable()
Restricts a template parameter to Op classes that implement the given OpInterface.
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
int64_t fromAPInt(llvm::APInt i)