LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ArrayToScalarPass.cpp
Go to the documentation of this file.
1//===-- ArrayToScalarPass.cpp -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
17/// 2. Run a dialect conversion that does the following:
18///
19/// - Replace `FieldReadOp` and `FieldWriteOp` targeting the fields that were split in step 1 so
20/// they instead perform scalar reads and writes from the new fields. The transformation is
21/// local to the current op. Therefore, when replacing the `FieldReadOp` a new array is created
/// locally and all uses of the `FieldReadOp` are replaced with the new array Value, then each
/// scalar field read is followed by a scalar write into the new array. Similarly, when replacing
24/// a `FieldWriteOp`, each element in the array operand needs a scalar read from the array
25/// followed by a scalar write to the new field. Making only local changes keeps this step
26/// simple and later steps will optimize.
27///
28/// - Replace `ArrayLengthOp` with the constant size of the selected dimension.
29///
30/// - Remove element initialization from `CreateArrayOp` and instead insert a list of
31/// `WriteArrayOp` immediately following.
32///
33/// - Desugar `InsertArrayOp` and `ExtractArrayOp` into their element-wise scalar reads/writes.
34///
35/// - Split arrays to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary
36/// create/read/write ops so the changes are as local as possible (just as described for
/// `FieldReadOp` and `FieldWriteOp`).
38///
/// 3. Run MLIR "sroa" pass to split each array with linear size `N` into `N` arrays of size 1 (to
/// prepare for the "mem2reg" pass, because its API does not allow for indexing to split aggregates).
41///
42/// 4. Run MLIR "mem2reg" pass to convert all of the size 1 array allocation and access into SSA
43/// values. This pass also runs several standard optimizations so the final result is condensed.
44///
45/// Note: This transformation imposes a "last write wins" semantics on array elements. If
46/// different/configurable semantics are added in the future, some additional transformation would
47/// be necessary before/during this pass so that multiple writes to the same index can be handled
48/// properly while they still exist.
49///
50/// Note: This transformation will introduce an undef op when there exists a read from an array
51/// index that was not earlier written to.
52///
53//===----------------------------------------------------------------------===//
54
#include "llzk/Util/Concepts.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>

#include <llvm/ADT/MapVector.h>
#include <llvm/Support/Debug.h>
71
72// Include the generated base pass class definitions.
73namespace llzk::array {
74#define GEN_PASS_DEF_ARRAYTOSCALARPASS
76} // namespace llzk::array
77
78using namespace mlir;
79using namespace llzk;
80using namespace llzk::array;
81using namespace llzk::component;
82using namespace llzk::function;
83
84#define DEBUG_TYPE "llzk-array-to-scalar"
85
86namespace {
87
89inline ArrayType splittableArray(ArrayType at) { return at.hasStaticShape() ? at : nullptr; }
90
92inline ArrayType splittableArray(Type t) {
93 if (ArrayType at = dyn_cast<ArrayType>(t)) {
94 return splittableArray(at);
95 } else {
96 return nullptr;
97 }
98}
99
101inline bool containsSplittableArrayType(Type t) {
102 return t
103 .walk([](ArrayType a) {
104 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
105 }).wasInterrupted();
106}
107
109template <typename T> bool containsSplittableArrayType(ValueTypeRange<T> types) {
110 for (Type t : types) {
111 if (containsSplittableArrayType(t)) {
112 return true;
113 }
114 }
115 return false;
116}
117
120size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
121 if (ArrayType at = splittableArray(t)) {
122 int64_t n = at.getNumElements();
123 assert(n >= 0);
124 assert(std::cmp_less_equal(n, std::numeric_limits<size_t>::max()));
125 size_t size = n;
126 collect.append(size, at.getElementType());
127 return size;
128 } else {
129 collect.push_back(t);
130 return 1;
131 }
132}
133
135template <typename TypeCollection>
136inline void splitArrayTypeTo(
137 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
138) {
139 for (Type t : types) {
140 size_t count = splitArrayTypeTo(t, collect);
141 if (originalIdxToSize) {
142 originalIdxToSize->push_back(count);
143 }
144 }
145}
146
149template <typename TypeCollection>
150inline SmallVector<Type>
151splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize = nullptr) {
152 SmallVector<Type> collect;
153 splitArrayTypeTo(types, collect, originalIdxToSize);
154 return collect;
155}
156
159SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
160 SmallVector<Value> operands;
161 for (Attribute a : index) {
162 // ASSERT: Attributes are index constants, created by ArrayType::getSubelementIndices().
163 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
164 assert(ia && ia.getType().isIndex());
165 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
166 }
167 return operands;
168}
169
170inline WriteArrayOp
171genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
172 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
173 return rewriter.create<WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
174}
175
179CallOp newCallOpWithSplitResults(
180 CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter
181) {
182 OpBuilder::InsertionGuard guard(rewriter);
183 rewriter.setInsertionPointAfter(oldCall);
184
185 Operation::result_range oldResults = oldCall.getResults();
186 CallOp newCall = rewriter.create<CallOp>(
187 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.getCallee(),
188 adaptor.getArgOperands()
189 );
190
191 auto newResults = newCall.getResults().begin();
192 for (Value oldVal : oldResults) {
193 if (ArrayType at = splittableArray(oldVal.getType())) {
194 Location loc = oldVal.getLoc();
195 // Generate `CreateArrayOp` and replace uses of the result with it.
196 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
197 rewriter.replaceAllUsesWith(oldVal, newArray);
198
199 // For all indices in the ArrayType (i.e., the element count), write the next
200 // result from the new CallOp to the new array.
201 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
202 assert(allIndices); // follows from legal() check
203 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
204 for (ArrayAttr subIdx : allIndices.value()) {
205 genWrite(loc, newArray, subIdx, *newResults, rewriter);
206 newResults++;
207 }
208 } else {
209 newResults++;
210 }
211 }
212 // erase the original CallOp
213 rewriter.eraseOp(oldCall);
214
215 return newCall;
216}
217
218inline ReadArrayOp
219genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
220 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
221 return rewriter.create<ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
222}
223
224// If the operand has ArrayType, add N reads from the array to the `newOperands` list otherwise add
225// the original operand to the list.
226void processInputOperand(
227 Location loc, Value operand, SmallVector<Value> &newOperands,
228 ConversionPatternRewriter &rewriter
229) {
230 if (ArrayType at = splittableArray(operand.getType())) {
231 std::optional<SmallVector<ArrayAttr>> indices = at.getSubelementIndices();
232 assert(indices.has_value() && "passed earlier hasStaticShape() check");
233 for (ArrayAttr index : indices.value()) {
234 newOperands.push_back(genRead(loc, operand, index, rewriter));
235 }
236 } else {
237 newOperands.push_back(operand);
238 }
239}
240
241// For each operand with ArrayType, add N reads from the array and use those N values instead.
242void processInputOperands(
243 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
244 ConversionPatternRewriter &rewriter
245) {
246 SmallVector<Value> newOperands;
247 for (Value v : operands) {
248 processInputOperand(op->getLoc(), v, newOperands, rewriter);
249 }
250 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
251 outputOpRef.assign(ValueRange(newOperands));
252 });
253}
254
255namespace {
256
257enum Direction {
259 SMALL_TO_LARGE,
261 LARGE_TO_SMALL,
262};
263
266template <Direction dir>
267inline void rewriteImpl(
268 ArrayAccessOpInterface op, ArrayType smallType, Value smallArr, Value largeArr,
269 ConversionPatternRewriter &rewriter
270) {
271 assert(smallType); // follows from legal() check
272 Location loc = op.getLoc();
273 MLIRContext *ctx = op.getContext();
274
275 ArrayAttr indexAsAttr = op.indexOperandsToAttributeArray();
276 assert(indexAsAttr); // follows from legal() check
277
278 // For all indices in the ArrayType (i.e., the element count), read from one array into the other
279 // (depending on direction flag).
280 std::optional<SmallVector<ArrayAttr>> subIndices = smallType.getSubelementIndices();
281 assert(subIndices); // follows from legal() check
282 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
283 for (ArrayAttr indexingTail : subIndices.value()) {
284 SmallVector<Attribute> joined;
285 joined.append(indexAsAttr.begin(), indexAsAttr.end());
286 joined.append(indexingTail.begin(), indexingTail.end());
287 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
288
289 if constexpr (dir == Direction::SMALL_TO_LARGE) {
290 auto init = genRead(loc, smallArr, indexingTail, rewriter);
291 genWrite(loc, largeArr, fullIndex, init, rewriter);
292 } else if constexpr (dir == Direction::LARGE_TO_SMALL) {
293 auto init = genRead(loc, largeArr, fullIndex, rewriter);
294 genWrite(loc, smallArr, indexingTail, init, rewriter);
295 }
296 }
297}
298
299} // namespace
300
301class SplitInsertArrayOp : public OpConversionPattern<InsertArrayOp> {
302public:
303 using OpConversionPattern<InsertArrayOp>::OpConversionPattern;
304
305 static bool legal(InsertArrayOp op) {
306 return !containsSplittableArrayType(op.getRvalue().getType());
307 }
308
309 LogicalResult match(InsertArrayOp op) const override { return failure(legal(op)); }
310
311 void
312 rewrite(InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
313 ArrayType at = splittableArray(op.getRvalue().getType());
314 rewriteImpl<SMALL_TO_LARGE>(
315 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
316 adaptor.getArrRef(), rewriter
317 );
318 rewriter.eraseOp(op);
319 }
320};
321
322class SplitExtractArrayOp : public OpConversionPattern<ExtractArrayOp> {
323public:
324 using OpConversionPattern<ExtractArrayOp>::OpConversionPattern;
325
326 static bool legal(ExtractArrayOp op) {
327 return !containsSplittableArrayType(op.getResult().getType());
328 }
329
330 LogicalResult match(ExtractArrayOp op) const override { return failure(legal(op)); }
331
332 void rewrite(ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
333 const override {
334 ArrayType at = splittableArray(op.getResult().getType());
335 // Generate `CreateArrayOp` in place of the current op.
336 auto newArray = rewriter.replaceOpWithNewOp<CreateArrayOp>(op, at);
337 rewriteImpl<LARGE_TO_SMALL>(
338 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
339 rewriter
340 );
341 }
342};
343
344class SplitInitFromCreateArrayOp : public OpConversionPattern<CreateArrayOp> {
345public:
346 using OpConversionPattern<CreateArrayOp>::OpConversionPattern;
347
348 static bool legal(CreateArrayOp op) { return op.getElements().empty(); }
349
350 LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); }
351
352 void
353 rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
354 // Remove elements from `op`
355 rewriter.modifyOpInPlace(op, [&op]() { op.getElementsMutable().clear(); });
356 // Generate an individual write for each initialization element
357 rewriter.setInsertionPointAfter(op);
358 Location loc = op.getLoc();
359 ArrayIndexGen idxGen = ArrayIndexGen::from(op.getType());
360 for (auto [i, init] : llvm::enumerate(adaptor.getElements())) {
361 // Convert the linear index 'i' into a multi-dim index
362 assert(std::cmp_less_equal(i, std::numeric_limits<int64_t>::max()));
363 std::optional<SmallVector<Value>> multiDimIdxVals =
364 idxGen.delinearize(static_cast<int64_t>(i), loc, rewriter);
365 // ASSERT: CreateArrayOp verifier ensures the number of elements provided matches the full
366 // linear array size so delinearization of `i` will not fail.
367 assert(multiDimIdxVals.has_value());
368 // Create the write
369 rewriter.create<WriteArrayOp>(loc, op.getResult(), ValueRange(*multiDimIdxVals), init);
370 }
371 }
372};
373
374class SplitArrayInFuncDefOp : public OpConversionPattern<FuncDefOp> {
375public:
376 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
377
378 inline static bool legal(FuncDefOp op) {
379 return !containsSplittableArrayType(op.getFunctionType());
380 }
381
382 // Create a new ArrayAttr like the one given but with repetitions of the elements according to the
383 // mapping defined by `originalIdxToSize`. In other words, if `originalIdxToSize[i] = n`, then `n`
384 // copies of `origAttrs[i]` are appended in its place.
385 static ArrayAttr replicateAttributesAsNeeded(
386 ArrayAttr origAttrs, const SmallVector<size_t> &originalIdxToSize,
387 const SmallVector<Type> &newTypes
388 ) {
389 if (origAttrs) {
390 assert(originalIdxToSize.size() == origAttrs.size());
391 if (originalIdxToSize.size() != newTypes.size()) {
392 SmallVector<Attribute> newArgAttrs;
393 for (auto [i, s] : llvm::enumerate(originalIdxToSize)) {
394 newArgAttrs.append(s, origAttrs[i]);
395 }
396 return ArrayAttr::get(origAttrs.getContext(), newArgAttrs);
397 }
398 }
399 return nullptr;
400 }
401
402 LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); }
403
404 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
405 // Update in/out types of the function to replace arrays with scalars
406 class Impl : public FunctionTypeConverter {
407 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
408
409 protected:
410 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
411 return splitArrayType(origTypes, &originalInputIdxToSize);
412 }
413 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
414 return splitArrayType(origTypes, &originalResultIdxToSize);
415 }
416 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
417 return replicateAttributesAsNeeded(origAttrs, originalInputIdxToSize, newTypes);
418 }
419 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
420 return replicateAttributesAsNeeded(origAttrs, originalResultIdxToSize, newTypes);
421 }
422
427 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
428 OpBuilder::InsertionGuard guard(rewriter);
429 rewriter.setInsertionPointToStart(&entryBlock);
430
431 for (unsigned i = 0; i < entryBlock.getNumArguments();) {
432 Value oldV = entryBlock.getArgument(i);
433 if (ArrayType at = splittableArray(oldV.getType())) {
434 Location loc = oldV.getLoc();
435 // Generate `CreateArrayOp` and replace uses of the argument with it.
436 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
437 rewriter.replaceAllUsesWith(oldV, newArray);
438 // Remove the argument from the block
439 entryBlock.eraseArgument(i);
440 // For all indices in the ArrayType (i.e., the element count), generate a new block
441 // argument and a write of that argument to the new array.
442 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
443 assert(allIndices); // follows from legal() check
444 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
445 for (ArrayAttr subIdx : allIndices.value()) {
446 BlockArgument newArg = entryBlock.insertArgument(i, at.getElementType(), loc);
447 genWrite(loc, newArray, subIdx, newArg, rewriter);
448 ++i;
449 }
450 } else {
451 ++i;
452 }
453 }
454 }
455 };
456 Impl().convert(op, rewriter);
457 }
458};
459
460class SplitArrayInReturnOp : public OpConversionPattern<ReturnOp> {
461public:
462 using OpConversionPattern<ReturnOp>::OpConversionPattern;
463
464 inline static bool legal(ReturnOp op) {
465 return !containsSplittableArrayType(op.getOperands().getTypes());
466 }
467
468 LogicalResult match(ReturnOp op) const override { return failure(legal(op)); }
469
470 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
471 processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter);
472 }
473};
474
475class SplitArrayInCallOp : public OpConversionPattern<CallOp> {
476public:
477 using OpConversionPattern<CallOp>::OpConversionPattern;
478
479 inline static bool legal(CallOp op) {
480 return !containsSplittableArrayType(op.getArgOperands().getTypes()) &&
481 !containsSplittableArrayType(op.getResultTypes());
482 }
483
484 LogicalResult match(CallOp op) const override { return failure(legal(op)); }
485
486 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
487 assert(isNullOrEmpty(op.getMapOpGroupSizesAttr()) && "structs must be previously flattened");
488
489 // Create new CallOp with split results first so, then process its inputs to split types
490 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
491 processInputOperands(
492 newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter
493 );
494 }
495};
496
497class ReplaceKnownArrayLengthOp : public OpConversionPattern<ArrayLengthOp> {
498public:
499 using OpConversionPattern<ArrayLengthOp>::OpConversionPattern;
500
502 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx, ArrayType baseArrType) {
503 if (splittableArray(baseArrType)) {
504 llvm::APInt idxAP;
505 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
506 uint64_t idx64 = idxAP.getZExtValue();
507 assert(std::cmp_less_equal(idx64, std::numeric_limits<size_t>::max()));
508 Attribute dimSizeAttr = baseArrType.getDimensionSizes()[static_cast<size_t>(idx64)];
509 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
510 return idxAP;
511 }
512 }
513 }
514 return std::nullopt;
515 }
516
517 inline static bool legal(ArrayLengthOp op) {
518 // rewrite() can only work with constant dim size, i.e., must consider it legal otherwise
519 return !getDimSizeIfKnown(op.getDim(), op.getArrRefType()).has_value();
520 }
521
522 LogicalResult match(ArrayLengthOp op) const override { return failure(legal(op)); }
523
524 void
525 rewrite(ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
526 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
527 assert(arrTy); // must have array type per ODS spec of ArrayLengthOp
528 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
529 assert(len.has_value()); // follows from legal() check
530 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, llzk::fromAPInt(len.value()));
531 }
532};
533
535using FieldInfo = std::pair<StringAttr, Type>;
537using LocalFieldReplacementMap = DenseMap<ArrayAttr, FieldInfo>;
539using FieldReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalFieldReplacementMap>>;
540
541class SplitArrayInFieldDefOp : public OpConversionPattern<FieldDefOp> {
542 SymbolTableCollection &tables;
543 FieldReplacementMap &repMapRef;
544
545public:
546 SplitArrayInFieldDefOp(
547 MLIRContext *ctx, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap
548 )
549 : OpConversionPattern<FieldDefOp>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
550
551 inline static bool legal(FieldDefOp op) { return !containsSplittableArrayType(op.getType()); }
552
553 LogicalResult match(FieldDefOp op) const override { return failure(legal(op)); }
554
555 void rewrite(FieldDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
556 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
557 assert(inStruct);
558 LocalFieldReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()];
559
560 ArrayType arrTy = dyn_cast<ArrayType>(op.getType());
561 assert(arrTy); // follows from legal() check
562 auto subIdxs = arrTy.getSubelementIndices();
563 assert(subIdxs.has_value());
564 Type elemTy = arrTy.getElementType();
565
566 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
567 for (ArrayAttr idx : subIdxs.value()) {
568 // Create scalar version of the field
569 FieldDefOp newField =
570 rewriter.create<FieldDefOp>(op.getLoc(), op.getSymNameAttr(), elemTy, op.getColumn());
571 newField.setPublicAttr(op.hasPublicAttr());
572 // Use SymbolTable to give it a unique name and store to the replacement map
573 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newField), elemTy);
574 }
575 rewriter.eraseOp(op);
576 }
577};
578
584template <
585 typename ImplClass, HasInterface<FieldRefOpInterface> FieldRefOpClass, typename GenHeaderType>
586class SplitArrayInFieldRefOp : public OpConversionPattern<FieldRefOpClass> {
587 SymbolTableCollection &tables;
588 const FieldReplacementMap &repMapRef;
589
590 // static check to ensure the functions are implemented in all subclasses
591 inline static void ensureImplementedAtCompile() {
592 static_assert(
593 sizeof(FieldRefOpClass) == 0, "SplitArrayInFieldRefOp not implemented for requested type."
594 );
595 }
596
597protected:
598 using OpAdaptor = typename FieldRefOpClass::Adaptor;
599
602 static GenHeaderType genHeader(FieldRefOpClass, ConversionPatternRewriter &) {
603 ensureImplementedAtCompile();
604 assert(false && "unreachable");
605 }
606
609 static void
610 forIndex(Location, GenHeaderType, ArrayAttr, FieldInfo, OpAdaptor, ConversionPatternRewriter &) {
611 ensureImplementedAtCompile();
612 assert(false && "unreachable");
613 }
614
615public:
616 SplitArrayInFieldRefOp(
617 MLIRContext *ctx, SymbolTableCollection &symTables, const FieldReplacementMap &fieldRepMap
618 )
619 : OpConversionPattern<FieldRefOpClass>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
620
621 static bool legal(FieldRefOpClass) {
622 ensureImplementedAtCompile();
623 assert(false && "unreachable");
624 }
625
626 LogicalResult match(FieldRefOpClass op) const override { return failure(ImplClass::legal(op)); }
627
628 void rewrite(FieldRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
629 const override {
630 StructType tgtStructTy = llvm::cast<FieldRefOpInterface>(op.getOperation()).getStructType();
631 assert(tgtStructTy);
632 auto tgtStructDef = tgtStructTy.getDefinition(tables, op);
633 assert(succeeded(tgtStructDef));
634
635 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
636
637 const LocalFieldReplacementMap &idxToName =
638 repMapRef.at(tgtStructDef->get()).at(op.getFieldNameAttr().getAttr());
639 // Split the array field write into a series of read array + write scalar field
640 for (auto [idx, newField] : idxToName) {
641 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newField, adaptor, rewriter);
642 }
643 rewriter.eraseOp(op);
644 }
645};
646
647class SplitArrayInFieldWriteOp
648 : public SplitArrayInFieldRefOp<SplitArrayInFieldWriteOp, FieldWriteOp, void *> {
649public:
650 using SplitArrayInFieldRefOp<
651 SplitArrayInFieldWriteOp, FieldWriteOp, void *>::SplitArrayInFieldRefOp;
652
653 static bool legal(FieldWriteOp op) { return !containsSplittableArrayType(op.getVal().getType()); }
654
655 static void *genHeader(FieldWriteOp, ConversionPatternRewriter &) { return nullptr; }
656
657 static void forIndex(
658 Location loc, void *, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
659 ConversionPatternRewriter &rewriter
660 ) {
661 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
662 rewriter.create<FieldWriteOp>(
663 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newField.first), scalarRead
664 );
665 }
666};
667
668class SplitArrayInFieldReadOp
669 : public SplitArrayInFieldRefOp<SplitArrayInFieldReadOp, FieldReadOp, CreateArrayOp> {
670public:
671 using SplitArrayInFieldRefOp<
672 SplitArrayInFieldReadOp, FieldReadOp, CreateArrayOp>::SplitArrayInFieldRefOp;
673
674 static bool legal(FieldReadOp op) {
675 return !containsSplittableArrayType(op.getResult().getType());
676 }
677
678 static CreateArrayOp genHeader(FieldReadOp op, ConversionPatternRewriter &rewriter) {
679 CreateArrayOp newArray =
680 rewriter.create<CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
681 rewriter.replaceAllUsesWith(op, newArray);
682 return newArray;
683 }
684
685 static void forIndex(
686 Location loc, CreateArrayOp newArray, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
687 ConversionPatternRewriter &rewriter
688 ) {
689 FieldReadOp scalarRead =
690 rewriter.create<FieldReadOp>(loc, newField.second, adaptor.getComponent(), newField.first);
691 genWrite(loc, newArray, idx, scalarRead, rewriter);
692 }
693};
694
/// Step 1 of the transformation: replace every `FieldDefOp` whose type contains a splittable array
/// with one scalar `FieldDefOp` per array element. The element-index -> new-field mapping for each
/// struct is recorded in `fieldRepMap`, which step 2 consumes.
LogicalResult
step1(ModuleOp modOp, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap) {
  MLIRContext *ctx = modOp.getContext();

  RewritePatternSet patterns(ctx);

  patterns.add<SplitArrayInFieldDefOp>(ctx, symTables, fieldRepMap);

  // Ops from the listed dialects are legal as-is, except `FieldDefOp`, which stays illegal until
  // its type no longer contains a splittable array (op-specific legality overrides dialect
  // legality).
  ConversionTarget target(*ctx);
  target.addLegalDialect<
      component::StructDialect, constrain::ConstrainDialect, arith::ArithDialect, scf::SCFDialect>(
  );
  target.addLegalOp<ModuleOp>();
  target.addDynamicallyLegalOp<FieldDefOp>(SplitArrayInFieldDefOp::legal);

  LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split array fields\n";);
  // Full conversion: fails if any op is still illegal after the patterns have been applied.
  return applyFullConversion(modOp, target, std::move(patterns));
}
715
/// Step 2 of the transformation: split or desugar every remaining use of splittable arrays.
/// This covers CreateArrayOp initializers, insert/extract, function signatures, returns and calls,
/// array lengths, and field reads/writes. The field read/write patterns look up the replacement
/// fields recorded in `fieldRepMap` by step 1.
LogicalResult
step2(ModuleOp modOp, SymbolTableCollection &symTables, const FieldReplacementMap &fieldRepMap) {
  MLIRContext *ctx = modOp.getContext();

  RewritePatternSet patterns(ctx);
  patterns.add<
      // clang-format off
      SplitInitFromCreateArrayOp,
      SplitInsertArrayOp,
      SplitExtractArrayOp,
      SplitArrayInFuncDefOp,
      SplitArrayInReturnOp,
      SplitArrayInCallOp,
      ReplaceKnownArrayLengthOp
      // clang-format on
      >(ctx);

  // The field access patterns additionally need the symbol tables and the step-1 field mapping.
  patterns.add<
      // clang-format off
      SplitArrayInFieldWriteOp,
      SplitArrayInFieldReadOp
      // clang-format on
      >(ctx, symTables, fieldRepMap);

  // The listed dialects are legal as-is; each op registered as dynamically legal below is
  // instead legal only when its pattern's `legal()` predicate holds.
  ConversionTarget target(*ctx);
  target.addLegalDialect<
      scf::SCFDialect>();
  target.addLegalOp<ModuleOp>();
  target.addDynamicallyLegalOp<CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
  target.addDynamicallyLegalOp<InsertArrayOp>(SplitInsertArrayOp::legal);
  target.addDynamicallyLegalOp<ExtractArrayOp>(SplitExtractArrayOp::legal);
  target.addDynamicallyLegalOp<FuncDefOp>(SplitArrayInFuncDefOp::legal);
  target.addDynamicallyLegalOp<ReturnOp>(SplitArrayInReturnOp::legal);
  target.addDynamicallyLegalOp<CallOp>(SplitArrayInCallOp::legal);
  target.addDynamicallyLegalOp<ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
  target.addDynamicallyLegalOp<FieldWriteOp>(SplitArrayInFieldWriteOp::legal);
  target.addDynamicallyLegalOp<FieldReadOp>(SplitArrayInFieldReadOp::legal);

  LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other array ops\n";);
  // Full conversion: fails if any op is still illegal after the patterns have been applied.
  return applyFullConversion(modOp, target, std::move(patterns));
}
760
761LogicalResult splitArrayCreateInit(ModuleOp modOp) {
762 SymbolTableCollection symTables;
763 FieldReplacementMap fieldRepMap;
764
765 // This is divided into 2 steps to simplify the implementation for field-related ops. The issue is
766 // that the conversions for field read/write expect the mapping of array index to field name+type
767 // to already be populated for the referenced field (although this could be computed on demand if
768 // desired but it complicates the implementation a bit).
769 if (failed(step1(modOp, symTables, fieldRepMap))) {
770 return failure();
771 }
772 return step2(modOp, symTables, fieldRepMap);
773}
774
775class ArrayToScalarPass : public llzk::array::impl::ArrayToScalarPassBase<ArrayToScalarPass> {
776 void runOnOperation() override {
777 ModuleOp module = getOperation();
778 // Separate array initialization from creation by removing the initialization list from
779 // CreateArrayOp and inserting the corresponding WriteArrayOp following it.
780 if (failed(splitArrayCreateInit(module))) {
781 signalPassFailure();
782 return;
783 }
784 OpPassManager nestedPM(ModuleOp::getOperationName());
785 // Use SROA (Destructurable* interfaces) to split each array with linear size N into N arrays of
786 // size 1. This is necessary because the mem2reg pass cannot deal with indexing and splitting up
787 // memory, i.e., it can only convert scalar memory access into SSA values.
788 nestedPM.addPass(createSROA());
789 // The mem2reg pass converts all of the size 1 array allocation and access into SSA values.
790 nestedPM.addPass(createMem2Reg());
791 if (failed(runPipeline(nestedPM, module))) {
792 signalPassFailure();
793 return;
794 }
795 }
796};
797
798} // namespace
799
801 return std::make_unique<ArrayToScalarPass>();
802};
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Definition Ops.cpp:207
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::TypedValue<::mlir::IndexType > getDim()
Definition Ops.cpp.inc:140
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
Definition Ops.h.inc:158
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:114
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::MutableOperandRange getElementsMutable()
Definition Ops.cpp.inc:416
::mlir::Operation::operand_range getElements()
Definition Ops.cpp.inc:408
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.cpp.inc:438
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.cpp.inc:928
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
Definition Ops.cpp.inc:1184
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:399
::mlir::StringAttr getSymNameAttr()
Definition Ops.cpp.inc:576
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
::mlir::MutableOperandRange getArgOperandsMutable()
Definition Ops.cpp.inc:232
CallOpAdaptor Adaptor
Definition Ops.h.inc:180
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
Definition Ops.cpp.inc:540
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
::mlir::Operation::operand_range getOperands()
Definition Ops.cpp.inc:1322
::mlir::MutableOperandRange getOperandsMutable()
Definition Ops.cpp.inc:1326
Restricts a template parameter to Op classes that implement the given OpInterface.
Definition Concepts.h:20
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
int64_t fromAPInt(llvm::APInt i)