LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
ArrayToScalarPass.cpp
Go to the documentation of this file.
1//===-- ArrayToScalarPass.cpp -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
17/// 2. Run a dialect conversion that does the following:
18///
19/// - Replace `FieldReadOp` and `FieldWriteOp` targeting the fields that were split in step 1 so
20/// they instead perform scalar reads and writes from the new fields. The transformation is
21/// local to the current op. Therefore, when replacing the `FieldReadOp` a new array is created
22/// locally and all uses of the `FieldReadOp` are replaced with the new array Value, then each
23/// scalar field read is followed by a scalar write into the new array. Similarly, when replacing
24/// a `FieldWriteOp`, each element in the array operand needs a scalar read from the array
25/// followed by a scalar write to the new field. Making only local changes keeps this step
26/// simple and later steps will optimize.
27///
28/// - Replace `ArrayLengthOp` with the constant size of the selected dimension.
29///
30/// - Remove element initialization from `CreateArrayOp` and instead insert a list of
31/// `WriteArrayOp` immediately following.
32///
33/// - Desugar `InsertArrayOp` and `ExtractArrayOp` into their element-wise scalar reads/writes.
34///
35/// - Split arrays to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary
36/// create/read/write ops so the changes are as local as possible (just as described for
37/// `FieldReadOp` and `FieldWriteOp`)
38///
39/// 3. Run MLIR "sroa" pass to split each array with linear size `N` into `N` arrays of size 1 (to
40/// prepare for "mem2reg" pass because its API does not allow for indexing to split aggregates).
41///
42/// 4. Run MLIR "mem2reg" pass to convert all of the size 1 array allocation and access into SSA
43/// values. This pass also runs several standard optimizations so the final result is condensed.
44///
45/// Note: This transformation imposes a "last write wins" semantics on array elements. If
46/// different/configurable semantics are added in the future, some additional transformation would
47/// be necessary before/during this pass so that multiple writes to the same index can be handled
48/// properly while they still exist.
49///
50/// Note: This transformation will introduce an undef op when there exists a read from an array
51/// index that was not earlier written to.
52///
53//===----------------------------------------------------------------------===//
54
#include "llzk/Util/Concepts.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>

#include <llvm/ADT/MapVector.h>
#include <llvm/Support/Debug.h>
70
71// Include the generated base pass class definitions.
72namespace llzk::array {
73#define GEN_PASS_DEF_ARRAYTOSCALARPASS
75} // namespace llzk::array
76
77using namespace mlir;
78using namespace llzk;
79using namespace llzk::array;
80using namespace llzk::component;
81using namespace llzk::function;
82
83#define DEBUG_TYPE "llzk-array-to-scalar"
84
85namespace {
86
88inline ArrayType splittableArray(ArrayType at) { return at.hasStaticShape() ? at : nullptr; }
89
91inline ArrayType splittableArray(Type t) {
92 if (ArrayType at = dyn_cast<ArrayType>(t)) {
93 return splittableArray(at);
94 } else {
95 return nullptr;
96 }
97}
98
100inline bool containsSplittableArrayType(Type t) {
101 return t
102 .walk([](ArrayType a) {
103 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
104 }).wasInterrupted();
105}
106
108template <typename T> bool containsSplittableArrayType(ValueTypeRange<T> types) {
109 for (Type t : types) {
110 if (containsSplittableArrayType(t)) {
111 return true;
112 }
113 }
114 return false;
115}
116
119void splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
120 if (ArrayType at = splittableArray(t)) {
121 int64_t n = at.getNumElements();
122 assert(std::cmp_less_equal(n, std::numeric_limits<SmallVector<Type>::size_type>::max()));
123 collect.append(n, at.getElementType());
124 } else {
125 collect.push_back(t);
126 }
127}
128
130template <typename TypeCollection>
131inline void splitArrayTypeTo(TypeCollection types, SmallVector<Type> &collect) {
132 for (Type t : types) {
133 splitArrayTypeTo(t, collect);
134 }
135}
136
139template <typename TypeCollection> inline SmallVector<Type> splitArrayType(TypeCollection types) {
140 SmallVector<Type> collect;
141 splitArrayTypeTo(types, collect);
142 return collect;
143}
144
147SmallVector<Value>
148genIndexConstants(ArrayAttr index, Location loc, ConversionPatternRewriter &rewriter) {
149 SmallVector<Value> operands;
150 for (Attribute a : index) {
151 // ASSERT: Attributes are index constants, created by ArrayType::getSubelementIndices().
152 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
153 assert(ia && ia.getType().isIndex());
154 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
155 }
156 return operands;
157}
158
159inline WriteArrayOp genWrite(
160 Location loc, Value baseArrayOp, ArrayAttr index, Value init,
161 ConversionPatternRewriter &rewriter
162) {
163 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
164 return rewriter.create<WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
165}
166
170CallOp newCallOpWithSplitResults(
171 CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter
172) {
173 OpBuilder::InsertionGuard guard(rewriter);
174 rewriter.setInsertionPointAfter(oldCall);
175
176 Operation::result_range oldResults = oldCall.getResults();
177 CallOp newCall = rewriter.create<CallOp>(
178 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.getCallee(),
179 adaptor.getArgOperands()
180 );
181
182 auto newResults = newCall.getResults().begin();
183 for (Value oldVal : oldResults) {
184 if (ArrayType at = splittableArray(oldVal.getType())) {
185 Location loc = oldVal.getLoc();
186 // Generate `CreateArrayOp` and replace uses of the result with it.
187 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
188 rewriter.replaceAllUsesWith(oldVal, newArray);
189
190 // For all indices in the ArrayType (i.e. the element count), write the next
191 // result from the new CallOp to the new array.
192 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
193 assert(allIndices); // follows from legal() check
194 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
195 for (ArrayAttr subIdx : allIndices.value()) {
196 genWrite(loc, newArray, subIdx, *newResults, rewriter);
197 newResults++;
198 }
199 } else {
200 newResults++;
201 }
202 }
203 // erase the original CallOp
204 rewriter.eraseOp(oldCall);
205
206 return newCall;
207}
208
213void processBlockArgs(Block &entryBlock, ConversionPatternRewriter &rewriter) {
214 OpBuilder::InsertionGuard guard(rewriter);
215 rewriter.setInsertionPointToStart(&entryBlock);
216
217 for (unsigned i = 0; i < entryBlock.getNumArguments();) {
218 Value oldV = entryBlock.getArgument(i);
219 if (ArrayType at = splittableArray(oldV.getType())) {
220 Location loc = oldV.getLoc();
221 // Generate `CreateArrayOp` and replace uses of the argument with it.
222 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
223 rewriter.replaceAllUsesWith(oldV, newArray);
224 // Remove the argument from the block
225 entryBlock.eraseArgument(i);
226 // For all indices in the ArrayType (i.e. the element count), generate a new block
227 // argument and a write of that argument to the new array.
228 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
229 assert(allIndices); // follows from legal() check
230 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
231 for (ArrayAttr subIdx : allIndices.value()) {
232 BlockArgument newArg = entryBlock.insertArgument(i, at.getElementType(), loc);
233 genWrite(loc, newArray, subIdx, newArg, rewriter);
234 ++i;
235 }
236 } else {
237 ++i;
238 }
239 }
240}
241
242inline ReadArrayOp
243genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
244 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
245 return rewriter.create<ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
246}
247
248// If the operand has ArrayType, add N reads from the array to the `newOperands` list otherwise add
249// the original operand to the list.
250void processInputOperand(
251 Location loc, Value operand, SmallVector<Value> &newOperands,
252 ConversionPatternRewriter &rewriter
253) {
254 if (ArrayType at = splittableArray(operand.getType())) {
255 std::optional<SmallVector<ArrayAttr>> indices = at.getSubelementIndices();
256 assert(indices.has_value() && "passed earlier hasStaticShape() check");
257 for (ArrayAttr index : indices.value()) {
258 newOperands.push_back(genRead(loc, operand, index, rewriter));
259 }
260 } else {
261 newOperands.push_back(operand);
262 }
263}
264
265// For each operand with ArrayType, add N reads from the array and use those N values instead.
266void processInputOperands(
267 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
268 ConversionPatternRewriter &rewriter
269) {
270 SmallVector<Value> newOperands;
271 for (Value v : operands) {
272 processInputOperand(op->getLoc(), v, newOperands, rewriter);
273 }
274 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
275 outputOpRef.assign(ValueRange(newOperands));
276 });
277}
278
279namespace {
280
281enum Direction {
283 SMALL_TO_LARGE,
285 LARGE_TO_SMALL,
286};
287
290template <Direction dir>
291inline void rewriteImpl(
292 ArrayAccessOpInterface op, ArrayType smallType, Value smallArr, Value largeArr,
293 ConversionPatternRewriter &rewriter
294) {
295 assert(smallType); // follows from legal() check
296 Location loc = op.getLoc();
297 MLIRContext *ctx = op.getContext();
298
299 ArrayAttr indexAsAttr = op.indexOperandsToAttributeArray();
300 assert(indexAsAttr); // follows from legal() check
301
302 // For all indices in the ArrayType (i.e. the element count), read from one array into the other
303 // (depending on direction flag).
304 std::optional<SmallVector<ArrayAttr>> subIndices = smallType.getSubelementIndices();
305 assert(subIndices); // follows from legal() check
306 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
307 for (ArrayAttr indexingTail : subIndices.value()) {
308 SmallVector<Attribute> joined;
309 joined.append(indexAsAttr.begin(), indexAsAttr.end());
310 joined.append(indexingTail.begin(), indexingTail.end());
311 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
312
313 if constexpr (dir == Direction::SMALL_TO_LARGE) {
314 auto init = genRead(loc, smallArr, indexingTail, rewriter);
315 genWrite(loc, largeArr, fullIndex, init, rewriter);
316 } else if constexpr (dir == Direction::LARGE_TO_SMALL) {
317 auto init = genRead(loc, largeArr, fullIndex, rewriter);
318 genWrite(loc, smallArr, indexingTail, init, rewriter);
319 }
320 }
321}
322
323} // namespace
324
325class SplitInsertArrayOp : public OpConversionPattern<InsertArrayOp> {
326public:
327 using OpConversionPattern<InsertArrayOp>::OpConversionPattern;
328
329 static bool legal(InsertArrayOp op) {
330 return !containsSplittableArrayType(op.getRvalue().getType());
331 }
332
333 LogicalResult match(InsertArrayOp op) const override { return failure(legal(op)); }
334
335 void
336 rewrite(InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
337 ArrayType at = splittableArray(op.getRvalue().getType());
338 rewriteImpl<SMALL_TO_LARGE>(
339 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
340 adaptor.getArrRef(), rewriter
341 );
342 rewriter.eraseOp(op);
343 }
344};
345
346class SplitExtractArrayOp : public OpConversionPattern<ExtractArrayOp> {
347public:
348 using OpConversionPattern<ExtractArrayOp>::OpConversionPattern;
349
350 static bool legal(ExtractArrayOp op) {
351 return !containsSplittableArrayType(op.getResult().getType());
352 }
353
354 LogicalResult match(ExtractArrayOp op) const override { return failure(legal(op)); }
355
356 void rewrite(ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
357 const override {
358 ArrayType at = splittableArray(op.getResult().getType());
359 // Generate `CreateArrayOp` in place of the current op.
360 auto newArray = rewriter.replaceOpWithNewOp<CreateArrayOp>(op, at);
361 rewriteImpl<LARGE_TO_SMALL>(
362 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
363 rewriter
364 );
365 }
366};
367
368class SplitInitFromCreateArrayOp : public OpConversionPattern<CreateArrayOp> {
369public:
370 using OpConversionPattern<CreateArrayOp>::OpConversionPattern;
371
372 static bool legal(CreateArrayOp op) { return op.getElements().empty(); }
373
374 LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); }
375
376 void
377 rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
378 // Remove elements from `op`
379 rewriter.modifyOpInPlace(op, [&op]() { op.getElementsMutable().clear(); });
380 // Generate an individual write for each initialization element
381 rewriter.setInsertionPointAfter(op);
382 Location loc = op.getLoc();
383 ArrayIndexGen idxGen = ArrayIndexGen::from(op.getType());
384 for (auto [i, init] : llvm::enumerate(adaptor.getElements())) {
385 // Convert the linear index 'i' into a multi-dim index
386 assert(std::cmp_less_equal(i, std::numeric_limits<int64_t>::max()));
387 std::optional<SmallVector<Value>> multiDimIdxVals =
388 idxGen.delinearize(static_cast<int64_t>(i), loc, rewriter);
389 // ASSERT: CreateArrayOp verifier ensures the number of elements provided matches the full
390 // linear array size so delinearization of `i` will not fail.
391 assert(multiDimIdxVals.has_value());
392 // Create the write
393 rewriter.create<WriteArrayOp>(loc, op.getResult(), ValueRange(*multiDimIdxVals), init);
394 }
395 }
396};
397
398class SplitArrayInFuncDefOp : public OpConversionPattern<FuncDefOp> {
399public:
400 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
401
402 inline static bool legal(FuncDefOp op) {
403 return !containsSplittableArrayType(op.getFunctionType());
404 }
405
406 LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); }
407
408 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
409 // Update in/out types of the function to replace arrays with scalars
410 FunctionType oldTy = op.getFunctionType();
411 SmallVector<Type> newInputs = splitArrayType(oldTy.getInputs());
412 SmallVector<Type> newOutputs = splitArrayType(oldTy.getResults());
413 FunctionType newTy =
414 FunctionType::get(oldTy.getContext(), TypeRange(newInputs), TypeRange(newOutputs));
415 if (newTy == oldTy) {
416 return; // nothing to change
417 }
418 rewriter.modifyOpInPlace(op, [&op, &newTy]() { op.setFunctionType(newTy); });
419
420 // If the function has a body, ensure the entry block arguments match the function inputs.
421 if (Region *body = op.getCallableRegion()) {
422 Block &entryBlock = body->front();
423 if (std::cmp_equal(entryBlock.getNumArguments(), newInputs.size())) {
424 return; // nothing to change
425 }
426 processBlockArgs(entryBlock, rewriter);
427 }
428 }
429};
430
431class SplitArrayInReturnOp : public OpConversionPattern<ReturnOp> {
432public:
433 using OpConversionPattern<ReturnOp>::OpConversionPattern;
434
435 inline static bool legal(ReturnOp op) {
436 return !containsSplittableArrayType(op.getOperands().getTypes());
437 }
438
439 LogicalResult match(ReturnOp op) const override { return failure(legal(op)); }
440
441 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
442 processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter);
443 }
444};
445
446class SplitArrayInCallOp : public OpConversionPattern<CallOp> {
447public:
448 using OpConversionPattern<CallOp>::OpConversionPattern;
449
450 inline static bool legal(CallOp op) {
451 return !containsSplittableArrayType(op.getArgOperands().getTypes()) &&
452 !containsSplittableArrayType(op.getResultTypes());
453 }
454
455 LogicalResult match(CallOp op) const override { return failure(legal(op)); }
456
457 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
458 assert(isNullOrEmpty(op.getMapOpGroupSizesAttr()) && "structs must be previously flattened");
459
460 // Create new CallOp with split results first so, then process its inputs to split types
461 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
462 processInputOperands(
463 newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter
464 );
465 }
466};
467
468class ReplaceKnownArrayLengthOp : public OpConversionPattern<ArrayLengthOp> {
469public:
470 using OpConversionPattern<ArrayLengthOp>::OpConversionPattern;
471
473 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx, ArrayType baseArrType) {
474 if (splittableArray(baseArrType)) {
475 llvm::APInt idxAP;
476 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
477 uint64_t idx64 = idxAP.getZExtValue();
478 assert(std::cmp_less_equal(idx64, std::numeric_limits<size_t>::max()));
479 Attribute dimSizeAttr = baseArrType.getDimensionSizes()[static_cast<size_t>(idx64)];
480 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
481 return idxAP;
482 }
483 }
484 }
485 return std::nullopt;
486 }
487
488 inline static bool legal(ArrayLengthOp op) {
489 // rewrite() can only work with constant dim size, i.e. must consider it legal otherwise
490 return !getDimSizeIfKnown(op.getDim(), op.getArrRefType()).has_value();
491 }
492
493 LogicalResult match(ArrayLengthOp op) const override { return failure(legal(op)); }
494
495 void
496 rewrite(ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
497 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
498 assert(arrTy); // must have array type per ODS spec of ArrayLengthOp
499 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
500 assert(len.has_value()); // follows from legal() check
501 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, llzk::fromAPInt(len.value()));
502 }
503};
504
506using FieldInfo = std::pair<StringAttr, Type>;
508using LocalFieldReplacementMap = DenseMap<ArrayAttr, FieldInfo>;
510using FieldReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalFieldReplacementMap>>;
511
512class SplitArrayInFieldDefOp : public OpConversionPattern<FieldDefOp> {
513 SymbolTableCollection &tables;
514 FieldReplacementMap &repMapRef;
515
516public:
517 SplitArrayInFieldDefOp(
518 MLIRContext *ctx, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap
519 )
520 : OpConversionPattern<FieldDefOp>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
521
522 inline static bool legal(FieldDefOp op) { return !containsSplittableArrayType(op.getType()); }
523
524 LogicalResult match(FieldDefOp op) const override { return failure(legal(op)); }
525
526 void rewrite(FieldDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
527 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
528 assert(inStruct);
529 LocalFieldReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()];
530
531 ArrayType arrTy = dyn_cast<ArrayType>(op.getType());
532 assert(arrTy); // follows from legal() check
533 auto subIdxs = arrTy.getSubelementIndices();
534 assert(subIdxs.has_value());
535 Type elemTy = arrTy.getElementType();
536
537 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
538 for (ArrayAttr idx : subIdxs.value()) {
539 // Create scalar version of the field
540 FieldDefOp newField =
541 rewriter.create<FieldDefOp>(op.getLoc(), op.getSymNameAttr(), elemTy, op.getColumn());
542 // Use SymbolTable to give it a unique name and store to the replacement map
543 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newField), elemTy);
544 }
545 rewriter.eraseOp(op);
546 }
547};
548
554template <
555 typename ImplClass, HasInterface<FieldRefOpInterface> FieldRefOpClass, typename GenHeaderType>
556class SplitArrayInFieldRefOp : public OpConversionPattern<FieldRefOpClass> {
557 SymbolTableCollection &tables;
558 const FieldReplacementMap &repMapRef;
559
560 // static check to ensure the functions are implemented in all subclasses
561 inline static void ensureImplementedAtCompile() {
562 static_assert(
563 sizeof(FieldRefOpClass) == 0, "SplitArrayInFieldRefOp not implemented for requested type."
564 );
565 }
566
567protected:
568 using OpAdaptor = typename FieldRefOpClass::Adaptor;
569
572 static GenHeaderType genHeader(FieldRefOpClass, ConversionPatternRewriter &) {
573 ensureImplementedAtCompile();
574 assert(false && "unreachable");
575 }
576
579 static void
580 forIndex(Location, GenHeaderType, ArrayAttr, FieldInfo, OpAdaptor, ConversionPatternRewriter &) {
581 ensureImplementedAtCompile();
582 assert(false && "unreachable");
583 }
584
585public:
586 SplitArrayInFieldRefOp(
587 MLIRContext *ctx, SymbolTableCollection &symTables, const FieldReplacementMap &fieldRepMap
588 )
589 : OpConversionPattern<FieldRefOpClass>(ctx), tables(symTables), repMapRef(fieldRepMap) {}
590
591 static bool legal(FieldRefOpClass) {
592 ensureImplementedAtCompile();
593 assert(false && "unreachable");
594 }
595
596 LogicalResult match(FieldRefOpClass op) const override { return failure(ImplClass::legal(op)); }
597
598 void rewrite(FieldRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
599 const override {
600 StructType tgtStructTy = llvm::cast<FieldRefOpInterface>(op.getOperation()).getStructType();
601 assert(tgtStructTy);
602 auto tgtStructDef = tgtStructTy.getDefinition(tables, op);
603 assert(succeeded(tgtStructDef));
604
605 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
606
607 const LocalFieldReplacementMap &idxToName =
608 repMapRef.at(tgtStructDef->get()).at(op.getFieldNameAttr().getAttr());
609 // Split the array field write into a series of read array + write scalar field
610 for (auto [idx, newField] : idxToName) {
611 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newField, adaptor, rewriter);
612 }
613 rewriter.eraseOp(op);
614 }
615};
616
617class SplitArrayInFieldWriteOp
618 : public SplitArrayInFieldRefOp<SplitArrayInFieldWriteOp, FieldWriteOp, void *> {
619public:
620 using SplitArrayInFieldRefOp<
621 SplitArrayInFieldWriteOp, FieldWriteOp, void *>::SplitArrayInFieldRefOp;
622
623 static bool legal(FieldWriteOp op) { return !containsSplittableArrayType(op.getVal().getType()); }
624
625 static void *genHeader(FieldWriteOp, ConversionPatternRewriter &) { return nullptr; }
626
627 static void forIndex(
628 Location loc, void *, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
629 ConversionPatternRewriter &rewriter
630 ) {
631 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
632 rewriter.create<FieldWriteOp>(
633 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newField.first), scalarRead
634 );
635 }
636};
637
638class SplitArrayInFieldReadOp
639 : public SplitArrayInFieldRefOp<SplitArrayInFieldReadOp, FieldReadOp, CreateArrayOp> {
640public:
641 using SplitArrayInFieldRefOp<
642 SplitArrayInFieldReadOp, FieldReadOp, CreateArrayOp>::SplitArrayInFieldRefOp;
643
644 static bool legal(FieldReadOp op) {
645 return !containsSplittableArrayType(op.getResult().getType());
646 }
647
648 static CreateArrayOp genHeader(FieldReadOp op, ConversionPatternRewriter &rewriter) {
649 CreateArrayOp newArray =
650 rewriter.create<CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
651 rewriter.replaceAllUsesWith(op, newArray);
652 return newArray;
653 }
654
655 static void forIndex(
656 Location loc, CreateArrayOp newArray, ArrayAttr idx, FieldInfo newField, OpAdaptor adaptor,
657 ConversionPatternRewriter &rewriter
658 ) {
659 FieldReadOp scalarRead =
660 rewriter.create<FieldReadOp>(loc, newField.second, adaptor.getComponent(), newField.first);
661 genWrite(loc, newArray, idx, scalarRead, rewriter);
662 }
663};
664
665LogicalResult
666step1(ModuleOp modOp, SymbolTableCollection &symTables, FieldReplacementMap &fieldRepMap) {
667 MLIRContext *ctx = modOp.getContext();
668
669 RewritePatternSet patterns(ctx);
670
671 patterns.add<SplitArrayInFieldDefOp>(ctx, symTables, fieldRepMap);
672
673 ConversionTarget target(*ctx);
674 target.addLegalDialect<
677 component::StructDialect, constrain::ConstrainDialect, arith::ArithDialect, scf::SCFDialect>(
678 );
679 target.addLegalOp<ModuleOp>();
680 target.addDynamicallyLegalOp<FieldDefOp>(SplitArrayInFieldDefOp::legal);
681
682 LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split array fields\n";);
683 return applyFullConversion(modOp, target, std::move(patterns));
684}
685
686LogicalResult
687step2(ModuleOp modOp, SymbolTableCollection &symTables, const FieldReplacementMap &fieldRepMap) {
688 MLIRContext *ctx = modOp.getContext();
689
690 RewritePatternSet patterns(ctx);
691 patterns.add<
692 // clang-format off
693 SplitInitFromCreateArrayOp,
694 SplitInsertArrayOp,
695 SplitExtractArrayOp,
696 SplitArrayInFuncDefOp,
697 SplitArrayInReturnOp,
698 SplitArrayInCallOp,
699 ReplaceKnownArrayLengthOp
700 // clang-format on
701 >(ctx);
702
703 patterns.add<
704 // clang-format off
705 SplitArrayInFieldWriteOp,
706 SplitArrayInFieldReadOp
707 // clang-format on
708 >(ctx, symTables, fieldRepMap);
709
710 ConversionTarget target(*ctx);
711 target.addLegalDialect<
715 scf::SCFDialect>();
716 target.addLegalOp<ModuleOp>();
717 target.addDynamicallyLegalOp<CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
718 target.addDynamicallyLegalOp<InsertArrayOp>(SplitInsertArrayOp::legal);
719 target.addDynamicallyLegalOp<ExtractArrayOp>(SplitExtractArrayOp::legal);
720 target.addDynamicallyLegalOp<FuncDefOp>(SplitArrayInFuncDefOp::legal);
721 target.addDynamicallyLegalOp<ReturnOp>(SplitArrayInReturnOp::legal);
722 target.addDynamicallyLegalOp<CallOp>(SplitArrayInCallOp::legal);
723 target.addDynamicallyLegalOp<ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
724 target.addDynamicallyLegalOp<FieldWriteOp>(SplitArrayInFieldWriteOp::legal);
725 target.addDynamicallyLegalOp<FieldReadOp>(SplitArrayInFieldReadOp::legal);
726
727 LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other array ops\n";);
728 return applyFullConversion(modOp, target, std::move(patterns));
729}
730
731LogicalResult splitArrayCreateInit(ModuleOp modOp) {
732 SymbolTableCollection symTables;
733 FieldReplacementMap fieldRepMap;
734
735 // This is divided into 2 steps to simplify the implementation for field-related ops. The issue is
736 // that the conversions for field read/write expect the mapping of array index to field name+type
737 // to already be populated for the referenced field (although this could be computed on demand if
738 // desired but it complicates the implementation a bit).
739 if (failed(step1(modOp, symTables, fieldRepMap))) {
740 return failure();
741 }
742 return step2(modOp, symTables, fieldRepMap);
743}
744
745class ArrayToScalarPass : public llzk::array::impl::ArrayToScalarPassBase<ArrayToScalarPass> {
746 void runOnOperation() override {
747 ModuleOp module = getOperation();
748 // Separate array initialization from creation by removing the initialization list from
749 // CreateArrayOp and inserting the corresponding WriteArrayOp following it.
750 if (failed(splitArrayCreateInit(module))) {
751 signalPassFailure();
752 return;
753 }
754 OpPassManager nestedPM(ModuleOp::getOperationName());
755 // Use SROA (Destructurable* interfaces) to split each array with linear size N into N arrays of
756 // size 1. This is necessary because the mem2reg pass cannot deal with indexing and splitting up
757 // memory, i.e. it can only convert scalar memory access into SSA values.
758 nestedPM.addPass(createSROA());
759 // The mem2reg pass converts all of the size 1 array allocation and access into SSA values.
760 nestedPM.addPass(createMem2Reg());
761 if (failed(runPipeline(nestedPM, module))) {
762 signalPassFailure();
763 return;
764 }
765 }
766};
767
768} // namespace
769
771 return std::make_unique<ArrayToScalarPass>();
772};
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Definition Ops.cpp:207
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::TypedValue<::mlir::IndexType > getDim()
Definition Ops.cpp.inc:140
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
Definition Ops.h.inc:158
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:114
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::MutableOperandRange getElementsMutable()
Definition Ops.cpp.inc:416
::mlir::Operation::operand_range getElements()
Definition Ops.cpp.inc:408
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.cpp.inc:438
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.cpp.inc:928
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
Definition Ops.cpp.inc:1184
::mlir::StringAttr getSymNameAttr()
Definition Ops.cpp.inc:576
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
::mlir::MutableOperandRange getArgOperandsMutable()
Definition Ops.cpp.inc:232
CallOpAdaptor Adaptor
Definition Ops.h.inc:180
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
Definition Ops.cpp.inc:540
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1130
::mlir::Region * getCallableRegion()
Returns the region on the current operation that is callable.
Definition Ops.h.inc:588
::mlir::Operation::operand_range getOperands()
Definition Ops.cpp.inc:1322
::mlir::MutableOperandRange getOperandsMutable()
Definition Ops.cpp.inc:1326
Restricts a template parameter to Op classes that implement the given OpInterface.
Definition Concepts.h:20
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
int64_t fromAPInt(llvm::APInt i)