LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
1//===-- LLZKFlatteningPass.cpp - Implements -llzk-flatten pass --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
28#include "llzk/Util/Concepts.h"
29#include "llzk/Util/Debug.h"
34
35#include <mlir/Dialect/Affine/IR/AffineOps.h>
36#include <mlir/Dialect/Affine/LoopUtils.h>
37#include <mlir/Dialect/Arith/IR/Arith.h>
38#include <mlir/Dialect/SCF/IR/SCF.h>
39#include <mlir/Dialect/SCF/Utils/Utils.h>
40#include <mlir/Dialect/Utils/StaticValueUtils.h>
41#include <mlir/IR/Attributes.h>
42#include <mlir/IR/BuiltinAttributes.h>
43#include <mlir/IR/BuiltinOps.h>
44#include <mlir/IR/BuiltinTypes.h>
45#include <mlir/Interfaces/InferTypeOpInterface.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Support/LLVM.h>
48#include <mlir/Support/LogicalResult.h>
49#include <mlir/Transforms/DialectConversion.h>
50#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
51
52#include <llvm/ADT/APInt.h>
53#include <llvm/ADT/DenseMap.h>
54#include <llvm/ADT/DepthFirstIterator.h>
55#include <llvm/ADT/STLExtras.h>
56#include <llvm/ADT/SmallVector.h>
57#include <llvm/ADT/TypeSwitch.h>
58#include <llvm/Support/Debug.h>
59
60// Include the generated base pass class definitions.
61namespace llzk::polymorphic {
62#define GEN_PASS_DECL_FLATTENINGPASS
63#define GEN_PASS_DEF_FLATTENINGPASS
65} // namespace llzk::polymorphic
66
67#include "SharedImpl.h"
68
69#define DEBUG_TYPE "llzk-flatten"
70
71using namespace mlir;
72using namespace llzk;
73using namespace llzk::array;
74using namespace llzk::component;
75using namespace llzk::constrain;
76using namespace llzk::felt;
77using namespace llzk::function;
78using namespace llzk::polymorphic;
79using namespace llzk::polymorphic::detail;
80
81namespace {
82
83class ConversionTracker {
85 bool modified;
88 DenseMap<StructType, StructType> structInstantiations;
90 DenseMap<StructType, StructType> reverseInstantiations;
93 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
94
95public:
96 bool isModified() const { return modified; }
97 void resetModifiedFlag() { modified = false; }
98 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
99
100 void recordInstantiation(StructType oldType, StructType newType) {
101 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
103 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
104 if (forwardResult.second) {
105 // Insertion was successful
106 // ASSERT: The reverse map does not contain this mapping either
107 assert(!reverseInstantiations.contains(newType));
108 reverseInstantiations[newType] = oldType;
109 // Set the modified flag
110 modified = true;
111 } else {
112 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
113 assert(forwardResult.first->getSecond() == newType);
114 // ASSERT: The reverse mapping is already present as well
115 assert(reverseInstantiations.lookup(newType) == oldType);
116 }
117 assert(structInstantiations.size() == reverseInstantiations.size());
118 }
119
120 /// Return the instantiated type of the given StructType, if any.
121 std::optional<StructType> getInstantiation(StructType oldType) const {
122 auto cachedResult = structInstantiations.find(oldType);
123 if (cachedResult != structInstantiations.end()) {
124 return cachedResult->second;
125 }
126 return std::nullopt;
127 }
128
129
130 DenseSet<SymbolRefAttr> getInstantiatedStructNames() const {
131 DenseSet<SymbolRefAttr> instantiatedNames;
132 for (const auto &[origRemoteTy, _] : structInstantiations) {
133 instantiatedNames.insert(origRemoteTy.getNameRef());
134 }
135 return instantiatedNames;
137
138 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
139 auto res = delayedDiagnostics.find(newType);
140 if (res == delayedDiagnostics.end()) {
141 return;
142 }
143
144 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
145 for (Diagnostic &diag : res->second) {
146 // Update any notes referencing an UnknownLoc to use the CallOp location.
147 for (Diagnostic &note : diag.getNotes()) {
148 assert(note.getNotes().empty() && "notes cannot have notes attached");
149 if (llvm::isa<UnknownLoc>(note.getLocation())) {
150 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
151 }
152 }
153 // Report. Based on InFlightDiagnostic::report().
154 engine.emit(std::move(diag));
155 }
156 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
157 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
158 // `compute()` calls it will only be reported at one of those locations, not all.
159 delayedDiagnostics.erase(newType);
160 }
161
162 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
163 return delayedDiagnostics[newType];
165
168 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
169 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
170 // Check if `oTy` is a struct with a known instantiation to `nTy`
171 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
172 // Note: The values in `structInstantiations` must be no-parameter struct types
173 // so there is no need for recursive check, simple equality is sufficient.
174 if (this->structInstantiations.lookup(oldStructType) == nTy) {
175 return true;
176 }
177 }
178 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
179 // that instantiation (i.e., the parameterized version of the instantiated struct)
180 // is a more concrete unification of `oTy`.
181 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
182 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
183 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
184 return true;
185 }
186 }
187 }
188 return false;
189 };
190
191 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
192 return true;
193 }
194 LLVM_DEBUG(
195 llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
196 << " with new type " << newType
197 << " because it does not define a compatible and more concrete type.\n";
198 );
199 return false;
200 }
201
202 template <typename T, typename U>
203 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
204 return llvm::all_of(
205 llvm::zip_equal(oldTypes, newTypes), [this, &patName](std::tuple<Type, Type> oldThenNew) {
206 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
207 }
208 );
209 }
210};
211
214struct MatchFailureListener : public RewriterBase::Listener {
215 bool hadFailure = false;
216
217 ~MatchFailureListener() override {}
218
219 void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
220 hadFailure = true;
221
222 InFlightDiagnostic diag = emitError(loc);
223 reasonCallback(*diag.getUnderlyingDiagnostic());
224 diag.report();
225 }
226};
227
228static LogicalResult
229applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
230 bool currStepModified = false;
231 MatchFailureListener failureListener;
232 LogicalResult result = applyPatternsGreedily(
233 modOp->getRegion(0), std::move(patterns),
234 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
235 &currStepModified
236 );
237 tracker.updateModifiedFlag(currStepModified);
238 return failure(result.failed() || failureListener.hadFailure);
239}
240
241template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
242 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
243 return isConcreteType(tyAttr.getValue(), AllowStructParams);
244 }
245 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
246 return !isDynamic(intAttr);
247 }
248 return false;
249}
250
252
253static inline bool tableOffsetIsntSymbol(FieldReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
255}
256
259class StructCloner {
260 ConversionTracker &tracker_;
261 ModuleOp rootMod;
262 SymbolTableCollection symTables;
263
264 class MappedTypeConverter : public TypeConverter {
265 StructType origTy;
266 StructType newTy;
267 const DenseMap<Attribute, Attribute> &paramNameToValue;
268
269 inline Attribute convertIfPossible(Attribute a) const {
270 auto res = this->paramNameToValue.find(a);
271 return (res != this->paramNameToValue.end()) ? res->second : a;
272 }
273
274 public:
275 MappedTypeConverter(
276 StructType originalType, StructType newType,
278 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
279 )
280 : TypeConverter(), origTy(originalType), newTy(newType),
281 paramNameToValue(paramNameToInstantiatedValue) {
282
283 addConversion([](Type inputTy) { return inputTy; });
284
285 addConversion([this](StructType inputTy) {
286 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
287
288 // Check for replacement of the full type
289 if (inputTy == this->origTy) {
290 return this->newTy;
291 }
292 // Check for replacement of parameter symbol names with concrete values
293 if (ArrayAttr inputTyParams = inputTy.getParams()) {
294 SmallVector<Attribute> updated;
295 for (Attribute a : inputTyParams) {
296 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
297 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
298 } else {
299 updated.push_back(convertIfPossible(a));
300 }
301 }
302 return StructType::get(
303 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
304 );
305 }
306 // Otherwise, return the type unchanged
307 return inputTy;
308 });
309
310 addConversion([this](ArrayType inputTy) {
311 // Check for replacement of parameter symbol names with concrete values
312 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
313 if (!dimSizes.empty()) {
314 SmallVector<Attribute> updated;
315 for (Attribute a : dimSizes) {
316 updated.push_back(convertIfPossible(a));
317 }
318 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
319 }
320 // Otherwise, return the type unchanged
321 return inputTy;
322 });
323
324 addConversion([this](TypeVarType inputTy) -> Type {
325 // Check for replacement of parameter symbol name with a concrete type
326 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
327 Type convertedType = tyAttr.getValue();
328 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
329 // different struct references a parameter name from that other struct, not from the
330 // current struct so the reference would be invalid.
331 if (isConcreteType(convertedType)) {
332 return convertedType;
333 }
334 }
335 return inputTy;
336 });
337 }
338 };
339
340 template <typename Impl, typename Op, typename... HandledAttrs>
341 class SymbolUserHelper : public OpConversionPattern<Op> {
342 private:
343 const DenseMap<Attribute, Attribute> &paramNameToValue;
344
345 SymbolUserHelper(
346 TypeConverter &converter, MLIRContext *ctx, unsigned Benefit,
347 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
348 )
349 : OpConversionPattern<Op>(converter, ctx, Benefit),
350 paramNameToValue(paramNameToInstantiatedValue) {}
351
352 public:
353 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
354
355 virtual Attribute getNameAttr(Op) const = 0;
356
357 virtual LogicalResult handleDefaultRewrite(
358 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
359 ) const {
360 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
361 }
362
363 LogicalResult
364 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
365 auto res = this->paramNameToValue.find(getNameAttr(op));
366 if (res == this->paramNameToValue.end()) {
367 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] no instantiation for " << op << '\n');
368 return failure();
369 }
370 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
371 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
372
373 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
374 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
375 }))),
376 ...);
377
378 return TS.Default([&](Attribute a) {
379 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
380 });
381 }
382 friend Impl;
383 };
384
 /// Pattern that replaces a ConstReadOp naming an instantiated struct parameter with a constant
 /// built from the instantiated value. IntegerAttr and FeltConstAttr values are handled here; any
 /// other attribute kind falls through to the helper's default handler.
385 class ClonedStructConstReadOpPattern
386 : public SymbolUserHelper<
387 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
 /// Destination for warnings produced during the rewrite; they are stored, not emitted, here.
388 SmallVector<Diagnostic> &diagnostics;
389
390 using super =
391 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
392
393 public:
394 ClonedStructConstReadOpPattern(
395 TypeConverter &converter, MLIRContext *ctx,
396 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
397 SmallVector<Diagnostic> &instantiationDiagnostics
398 )
399 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
400 // instead of the GeneralTypeReplacePattern<ConstReadOp> from newGeneralRewritePatternSet().
401 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue),
402 diagnostics(instantiationDiagnostics) {}
403
 /// The parameter name read by the op is the key used for the instantiation lookup.
404 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
405
 /// Rewrite for a parameter instantiated with an integer value. The replacement is chosen from
 /// the op's original result type: felt, index, or 1-bit signless integer. Any other result type
 /// is reported as an op error.
406 LogicalResult handleRewrite(
407 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
408 ) const {
409 APInt attrValue = a.getValue();
410 Type origResTy = op.getType();
411 if (llvm::isa<FeltType>(origResTy)) {
 // NOTE(review): the line that opens this call is not present in this listing; only its
 // arguments (a FeltConstAttr built from the integer value) are visible.
413 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
414 );
415 return success();
416 }
417
418 if (llvm::isa<IndexType>(origResTy)) {
 // NOTE(review): the statement that replaces the op for the index case is not present in this
 // listing.
420 return success();
421 }
422
423 if (origResTy.isSignlessInteger(1)) {
424 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
425 if (attrValue.isZero()) {
426 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
427 return success();
428 }
429 if (!attrValue.isOne()) {
 // The warning is queued in `diagnostics` rather than emitted here. Its second note uses an
 // unknown location.
430 Location opLoc = op.getLoc();
431 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
432 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
433 if (getContext()->shouldPrintOpOnDiagnostic()) {
434 diag.attachNote(opLoc) << "see current operation: " << *op;
435 }
436 diag.attachNote(UnknownLoc::get(getContext()))
437 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \""
438 << sym << "\" for this call";
439 diagnostics.push_back(std::move(diag));
440 }
441 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
442 return success();
443 }
444 return op->emitOpError().append("unexpected result type ", origResTy);
445 }
446
 /// Rewrite for a parameter instantiated with a felt constant: the op becomes a FeltConstantOp
 /// carrying that attribute.
447 LogicalResult handleRewrite(
448 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
449 ) const {
450 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
451 return success();
452 }
453 };
454
455 class ClonedStructFieldReadOpPattern
456 : public SymbolUserHelper<
457 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
458 using super =
459 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
460
461 public:
462 ClonedStructFieldReadOpPattern(
463 TypeConverter &converter, MLIRContext *ctx,
464 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
465 )
466 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
467 // instead of the GeneralTypeReplacePattern<FieldReadOp> from newGeneralRewritePatternSet().
468 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue) {}
469
470 Attribute getNameAttr(FieldReadOp op) const override {
471 return op.getTableOffset().value_or(nullptr);
472 }
473
474 template <typename Attr>
475 LogicalResult handleRewrite(
476 Attribute, FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
477 ) const {
478 rewriter.modifyOpInPlace(op, [&]() {
479 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
480 });
481
482 return success();
483 }
484
485 LogicalResult matchAndRewrite(
486 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
487 ) const override {
488 if (tableOffsetIsntSymbol(op)) {
489 return failure();
490 }
491
492 return super::matchAndRewrite(op, adaptor, rewriter);
493 }
494 };
495
 /// Clone the StructDefOp that defines `typeAtCaller`, binding every struct parameter that has a
 /// concrete value in `typeAtCallerParams`, and convert the body of the clone accordingly.
 ///
 /// Returns the struct type referring to the clone, carrying only the caller parameters that were
 /// not concrete. Returns failure when the definition cannot be found, when no parameter is
 /// concrete, or when the full conversion of the cloned body fails.
496 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
497 // Find the StructDefOp for the original StructType
498 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.getDefinition(symTables, rootMod);
499 if (failed(r)) {
500 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
501 return failure(); // getDefinition() already emits a sufficient error message
502 }
503
504 StructDefOp origStruct = r->get();
505 StructType typeAtDef = origStruct.getType();
506 MLIRContext *ctx = origStruct.getContext();
507
508 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
509 DenseMap<Attribute, Attribute> paramNameToConcrete;
510 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
511 // where the original attribute from the current instantiation site was not concrete. This is
512 // used for generating the new struct name. See `BuildShortTypeString::from()`.
513 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
514 // Parameter list for the new StructDefOp containing the names that must be preserved because
515 // they were not assigned concrete values at the current instantiation site.
516 ArrayAttr reducedParamNameList = nullptr;
517 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
518 ArrayAttr reducedCallerParams = nullptr;
 // Scoped block: partition the parameters into concrete and non-concrete ones and fill the
 // four variables declared above.
519 {
520 ArrayAttr paramNames = typeAtDef.getParams();
521
522 // pre-conditions
523 assert(!isNullOrEmpty(paramNames));
524 assert(paramNames.size() == typeAtCallerParams.size());
525
526 SmallVector<Attribute> remainingNames;
527 SmallVector<Attribute> nonConcreteParams;
528 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
529 Attribute next = typeAtCallerParams[i];
530 if (isConcreteAttr<false>(next)) {
531 paramNameToConcrete[paramNames[i]] = next;
532 attrsForInstantiatedNameSuffix.push_back(next);
533 } else {
534 remainingNames.push_back(paramNames[i]);
535 nonConcreteParams.push_back(next);
536 attrsForInstantiatedNameSuffix.push_back(nullptr);
537 }
538 }
539 // post-conditions
540 assert(remainingNames.size() == nonConcreteParams.size());
541 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
542 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
543
 // Without at least one concrete parameter there is nothing to instantiate.
544 if (paramNameToConcrete.empty()) {
545 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
546 return failure();
547 }
 // The two reduced lists stay null when every parameter was concrete.
548 if (!remainingNames.empty()) {
549 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
550 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
551 }
552 }
553
554 // Clone the original struct, apply the new name, and set the parameter list of the new struct
555 // to contain only those that did not have concrete instantiated values.
556 StructDefOp newStruct = origStruct.clone();
557 newStruct.setConstParamsAttr(reducedParamNameList);
 // NOTE(review): the line that opens the name-building call is not present in this listing;
 // only its arguments (leaf name of the original type and the name-suffix attributes) are visible.
558 newStruct.setSymName(
560 typeAtCaller.getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
561 )
562 );
563
564 // Insert 'newStruct' into the parent ModuleOp of the original StructDefOp. Use the
565 // `SymbolTable::insert()` function directly so that the name will be made unique.
566 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>(); // parent is ModuleOp per ODS
567 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
568 // Retrieve the new type AFTER inserting since the name may be appended to make it unique and
569 // use the remaining non-concrete parameters from the original type.
570 StructType newRemoteType = newStruct.getType(reducedCallerParams);
571 LLVM_DEBUG({
572 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
573 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
574 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
575 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
576 });
577
578 // Within the new struct, replace all references to the original StructType (i.e., the
579 // locally-parameterized version) with the new locally-parameterized StructType,
580 // and replace all uses of the removed struct parameters with the concrete values.
581 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
582 ConversionTarget target =
583 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
584 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
585 // Legal if it's not in the map of concrete attribute instantiations
586 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
587 });
588
 // Warnings raised while rewriting ConstReadOps are stored under `newRemoteType` in the tracker.
589 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
590 patterns.add<ClonedStructConstReadOpPattern>(
591 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
592 );
593 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
594 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
595 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
596 return failure();
597 }
598 return newRemoteType;
599 }
600
601public:
602 StructCloner(ConversionTracker &tracker, ModuleOp root)
603 : tracker_(tracker), rootMod(root), symTables() {}
604
605 FailureOr<StructType> createInstantiatedClone(StructType orig) {
606 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
607 if (ArrayAttr params = orig.getParams()) {
608 return genClone(orig, params.getValue());
609 }
610 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
611 return failure();
612 }
613};
614
615class ParameterizedStructUseTypeConverter : public TypeConverter {
616 ConversionTracker &tracker_;
617 StructCloner cloner;
618
619public:
620 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
621 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
622
623 addConversion([](Type inputTy) { return inputTy; });
624
625 addConversion([this](StructType inputTy) -> StructType {
626 // First check for a cached entry
627 if (auto opt = tracker_.getInstantiation(inputTy)) {
628 return opt.value();
629 }
630
631 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
632 // done, return the original type to indicate that it's still legal (for this step at least).
633 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
634 if (failed(cloneRes)) {
635 return inputTy;
636 }
637 StructType newTy = cloneRes.value();
638 LLVM_DEBUG(
639 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
640 << " as " << newTy << '\n'
641 );
642 tracker_.recordInstantiation(inputTy, newTy);
643 return newTy;
644 });
645
646 addConversion([this](ArrayType inputTy) {
647 return inputTy.cloneWith(convertType(inputTy.getElementType()));
648 });
649 }
650};
651
/// Conversion pattern for CallOp: converts the result types and, when the callee is a struct
/// "compute" or "constrain" function, retargets the callee symbol at the converted struct type.
652class CallStructFuncPattern : public OpConversionPattern<CallOp> {
 /// Used to emit diagnostics that were delayed until the instantiating call is known.
653 ConversionTracker &tracker_;
654
655public:
656 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
657 // Must use higher benefit than CallOpClassReplacePattern so this pattern will be applied
658 // instead of the CallOpClassReplacePattern from newGeneralRewritePatternSet().
659 : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/2), tracker_(tracker) {}
660
661 LogicalResult matchAndRewrite(
662 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
663 ) const override {
664 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');
665
666 // Convert the result types of the CallOp
667 SmallVector<Type> newResultTypes;
668 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
669 return op->emitError("Could not convert Op result types.");
670 }
671 LLVM_DEBUG({
672 llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
673 << debug::toStringList(newResultTypes) << '\n';
674 });
675
676 // Update the callee to reflect the new struct target if necessary. These checks are based on
677 // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
678 // Instead they must come from the converted versions.
679 SymbolRefAttr calleeAttr = op.getCalleeAttr();
680 if (op.calleeIsStructCompute()) {
 // For "compute" the struct type is taken from the single converted result type. This is
 // also where diagnostics delayed for that type are emitted, at this call.
681 if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
682 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
683 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
684 tracker_.reportDelayedDiagnostics(newStTy, op);
685 }
686 } else if (op.calleeIsStructConstrain()) {
 // For "constrain" the struct type is taken from the type of the first converted argument.
687 if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
688 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
689 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
690 }
691 }
692
693 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
 // NOTE(review): the line that declares `newOp` and opens the replacement call is not present
 // in this listing; only the call's arguments are visible.
695 rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
696 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
697 );
698 (void)newOp; // tell compiler it's intentionally unused in release builds
699 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
700 return success();
701 }
702};
703
704// This one ensures FieldDefOp types are converted even if there are no reads/writes to them.
705class FieldDefOpPattern : public OpConversionPattern<FieldDefOp> {
706public:
707 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
708 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
709 // instead of the GeneralTypeReplacePattern<FieldDefOp> from newGeneralRewritePatternSet().
710 : OpConversionPattern<FieldDefOp>(converter, ctx, /*benefit=*/2) {}
711
712 LogicalResult matchAndRewrite(
713 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
714 ) const override {
715 LLVM_DEBUG(llvm::dbgs() << "[FieldDefOpPattern] FieldDefOp: " << op << '\n');
716
717 Type oldFieldType = op.getType();
718 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
719 if (oldFieldType == newFieldType) {
720 // nothing changed
721 return failure();
722 }
723 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.setType(newFieldType); });
724 return success();
725 }
726};
727
728LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
729 MLIRContext *ctx = modOp.getContext();
730 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
731 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
732 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
733 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
734 return applyPartialConversion(modOp, target, std::move(patterns));
735}
736
737} // namespace Step1_InstantiateStructs
738
739namespace Step2_Unroll {
740
741// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
742template <HasInterface<LoopLikeOpInterface> OpClass>
743class LoopUnrollPattern : public OpRewritePattern<OpClass> {
744public:
745 using OpRewritePattern<OpClass>::OpRewritePattern;
746
747 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
748 if (auto maybeConstant = getConstantTripCount(loopOp)) {
749 uint64_t tripCount = *maybeConstant;
750 if (tripCount == 0) {
751 rewriter.eraseOp(loopOp);
752 return success();
753 } else if (tripCount == 1) {
754 return loopOp.promoteIfSingleIteration(rewriter);
755 }
756 return loopUnrollByFactor(loopOp, tripCount);
757 }
758 return failure();
759 }
760
761private:
764 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
765 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
766 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
767 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
768 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
769 return std::nullopt;
770 }
771 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
772 }
773};
774
775LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
776 MLIRContext *ctx = modOp.getContext();
777 RewritePatternSet patterns(ctx);
778 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
779 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
780
781 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
782}
783} // namespace Step2_Unroll
784
786
787// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
788// version uses a basic loop instead of llvm::map_to_vector().
789std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
790 SmallVector<int64_t> res;
791 for (OpFoldResult ofr : ofrs) {
792 std::optional<int64_t> cv = getConstantIntValue(ofr);
793 if (!cv.has_value()) {
794 return std::nullopt;
795 }
796 res.push_back(cv.value());
797 }
798 return res;
799}
800
801struct AffineMapFolder {
802 struct Input {
803 OperandRangeRange mapOpGroups;
804 DenseI32ArrayAttr dimsPerGroup;
805 ArrayRef<Attribute> paramsOfStructTy;
806 };
807
808 struct Output {
809 SmallVector<SmallVector<Value>> mapOpGroups;
810 SmallVector<int32_t> dimsPerGroup;
811 SmallVector<Attribute> paramsOfStructTy;
812 };
813
814 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
815 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
816 return ValueRange(grp);
817 });
818 }
819
820 static LogicalResult
821 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
822 if (in.mapOpGroups.empty()) {
823 // No affine map operands so nothing to do
824 return failure();
825 }
826
827 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
828 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
829
830 size_t idx = 0; // index in `mapOpGroups`, i.e., the number of AffineMapAttr encountered
831 for (Attribute sizeAttr : in.paramsOfStructTy) {
832 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
833 ValueRange currMapOps = in.mapOpGroups[idx++];
834 LLVM_DEBUG(
835 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
836 << '\n'
837 );
838 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
839 LLVM_DEBUG(
840 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
841 << debug::toStringList(currMapOpsCast) << '\n'
842 );
843 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
844 SmallVector<Attribute> result;
845 bool hasPoison = false; // indicates divide by 0 or mod by <1
846 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
847 return rewriter.getIndexAttr(v);
848 });
849 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
850 if (hasPoison) {
851 // Diagnostic remark: could be removed for release builds if too noisy
852 op->emitRemark()
853 .append(
854 "Cannot fold affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
855 " due to divide by 0 or modulus with negative divisor"
856 )
857 .report();
858 return failure();
859 }
860 if (failed(foldResult)) {
861 // Diagnostic remark: could be removed for release builds if too noisy
862 op->emitRemark()
863 .append(
864 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " failed"
865 )
866 .report();
867 return failure();
868 }
869 if (result.size() != 1) {
870 // Diagnostic remark: could be removed for release builds if too noisy
871 op->emitRemark()
872 .append(
873 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
874 " produced ", result.size(), " results but expected 1"
875 )
876 .report();
877 return failure();
878 }
879 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
880 out.paramsOfStructTy.push_back(result[0]);
881 continue;
882 }
883 // If affine but not foldable, preserve the map ops
884 out.mapOpGroups.emplace_back(currMapOps);
885 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
886 }
887 // If not affine and foldable, preserve the original
888 out.paramsOfStructTy.push_back(sizeAttr);
889 }
890 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
891 assert(
892 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
893 "produced wrong number of dimensions"
894 );
895
896 return success();
897 }
898};
899
/// Rewrite pattern on CreateArrayOp: uses AffineMapFolder::fold() on the dimension sizes of the
/// op's ArrayType and, when that yields a different ArrayType, replaces the op so that it has
/// the new result type together with the map operand groups that fold() preserved.
// NOTE(review): this listing appears to be missing lines inside matchAndRewrite (an initializer
// between `op.getMapOperands()` and `getDimensionSizes()`, and the callee of the final call that
// takes `rewriter, op, newResultType, ...`) -- confirm against the real file before editing.
901class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
// Only read inside an assert(), hence [[maybe_unused]].
902 [[maybe_unused]]
903 ConversionTracker &tracker_;
904
905public:
906 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
907 : OpRewritePattern(ctx), tracker_(tracker) {}
908
// Returns failure() when folding fails or when the folded type equals the original type.
909 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
910 ArrayType oldResultType = op.getType();
911
912 AffineMapFolder::Output out;
913 AffineMapFolder::Input in = {
914 op.getMapOperands(),
916 oldResultType.getDimensionSizes(),
917 };
918 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
919 return failure();
920 }
921
// Same element type, dimension sizes taken from the fold output.
922 ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
923 if (newResultType == oldResultType) {
924 // nothing changed
925 return failure();
926 }
927 // ASSERT: folding only preserves the original Attribute or converts affine to integer
928 assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
929 LLVM_DEBUG(
930 llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
931 << newResultType << " in \"" << op << "\"\n"
932 );
934 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
935 );
936 return success();
937 }
938};
939
/// Rewrite pattern on CallOp that only applies when the callee is a struct "compute()" function.
/// It computes a new StructType for the call result in one of two ways:
///  - if the call has affine map operands, by folding them with AffineMapFolder::fold();
///  - otherwise, by unifying the target function's argument types with the call's argument types
///    (see instantiateViaTargetType()).
/// The op is replaced only if the new type differs from the old one and the tracker reports the
/// conversion as legal.
// NOTE(review): the declarations of `oldRetTy` and `newOp` (and one initializer of `in`) are not
// present in this listing -- lines appear to be missing; confirm against the real file.
941class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
942 ConversionTracker &tracker_;
943
944public:
945 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
946 : OpRewritePattern(ctx), tracker_(tracker) {}
947
948 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
949 if (!op.calleeIsStructCompute()) {
950 // this pattern only applies when the callee is "compute()" within a struct
951 return failure();
952 }
953 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
955 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
956 ArrayAttr params = oldRetTy.getParams();
957 if (isNullOrEmpty(params)) {
958 // nothing to do if the StructType is not parameterized
959 return failure();
960 }
961
962 AffineMapFolder::Output out;
963 AffineMapFolder::Input in = {
964 op.getMapOperands(),
966 params.getValue(),
967 };
968 if (!in.mapOpGroups.empty()) {
969 // If there are affine map operands, attempt to fold them to a constant.
970 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
971 return failure();
972 }
973 LLVM_DEBUG({
974 llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
975 });
976 } else {
977 // If there are no affine map operands, attempt to refine the result type of the CallOp using
978 // the function argument types and the type of the target function.
979 auto callArgTypes = op.getArgOperands().getTypes();
980 if (callArgTypes.empty()) {
981 // no refinement possible if no function arguments
982 return failure();
983 }
984 SymbolTableCollection tables;
985 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
986 if (failed(lookupRes)) {
987 return failure();
988 }
989 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
990 return failure();
991 }
992 LLVM_DEBUG({
993 llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
994 "result type params: "
995 << debug::toStringList(out.paramsOfStructTy) << '\n';
996 });
997 }
998
// Same struct name, parameters taken from whichever branch above populated `out`.
999 StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
1000 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
1001 if (newRetTy == oldRetTy) {
1002 // nothing changed
1003 return failure();
1004 }
1005 // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
1006 // original Attribute or convert to a concrete attribute via the unification process. Thus, if
1007 // the conversion here is illegal it means there is a type conflict within the LLZK code that
1008 // prevents instantiation of the struct with the requested type.
1009 if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
1010 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1011 diag.append(
1012 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1013 ", but found ", oldRetTy
1014 );
1015 });
1016 }
1017 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
1019 rewriter, op, TypeRange {newRetTy}, op.getCallee(),
1020 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
1021 );
1022 (void)newOp; // tell compiler it's intentionally unused in release builds
1023 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1024 return success();
1025 }
1026
1027private:
// Fills `out.paramsOfStructTy` with one attribute per entry of `in.paramsOfStructTy`: entries
// that are already concrete at the call site are kept; for the others, the unification of the
// target function's argument types against `callArgTypes` is consulted and a concrete result
// (per isConcreteAttr<false>) replaces the entry, otherwise the call-site attribute is kept.
// Returns failure() when every call-site parameter is already concrete. On success
// `out.mapOpGroups` and `out.dimsPerGroup` are left empty (asserted as post-conditions).
1030 inline LogicalResult instantiateViaTargetType(
1031 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1032 OperandRange::type_range callArgTypes, FuncDefOp targetFunc
1033 ) const {
1034 assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
1035 ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
1036 assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
1037 assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this
1038
1039 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1040 // Nothing can change if everything is already concrete
1041 return failure();
1042 }
1043
1044 LLVM_DEBUG({
1045 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1046 << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
1047 llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
1048 << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
1049 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1050 << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
1051 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1052 << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
1053 });
1054
// LHS of the unification is the target function's argument types, RHS is the call's.
1055 UnificationMap unifications;
1056 bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
1057 (void)unifies; // tell compiler it's intentionally unused in builds without assertions
1058 assert(unifies && "should have been checked by verifiers");
1059
1060 LLVM_DEBUG({
1061 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1062 << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
1063 });
1064
1065 // Check for LHS SymRef (i.e., from the target function) that have RHS concrete Attributes (i.e.
1066 // from the call argument types) without any struct parameters (because the type with concrete
1067 // struct parameters will be used to instantiate the target struct rather than the fully
1068 // flattened struct type resulting in type mismatch of the callee to target) and perform those
1069 // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
1070 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1071 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1072 [&unifications](std::tuple<Attribute, Attribute> p) {
1073 Attribute fromCall = std::get<1>(p);
1074 // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
1075 // non-parameterized concrete unification for the target struct parameter symbol.
1076 if (!isConcreteAttr<>(fromCall)) {
1077 Attribute fromTgt = std::get<0>(p);
1078 LLVM_DEBUG({
1079 llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
1080 llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
1081 });
1082 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1083 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1084 if (it != unifications.end()) {
1085 Attribute unifiedAttr = it->second;
1086 LLVM_DEBUG({
1087 llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
1088 });
1089 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1090 return unifiedAttr;
1091 }
1092 }
1093 }
1094 return fromCall;
1095 }
1096 );
1097
1098 out.paramsOfStructTy = newReturnStructParams;
1099 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
1100 assert(out.mapOpGroups.empty() && "post-condition");
1101 assert(out.dimsPerGroup.empty() && "post-condition");
1102 return success();
1103 }
1104};
1105
1106LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1107 MLIRContext *ctx = modOp.getContext();
1108 RewritePatternSet patterns(ctx);
1109 patterns.add<
1110 InstantiateAtCreateArrayOp, // CreateArrayOp
1111 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1112 >(ctx, tracker);
1113
1114 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1115}
1116
1117} // namespace Step3_InstantiateAffineMaps
1118
1120
1122class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1123 ConversionTracker &tracker_;
1124
1125public:
1126 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1127 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1128
1129 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1130 Value createResult = op.getResult();
1131 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1132 assert(createResultType && "CreateArrayOp must produce ArrayType");
1133 Type oldResultElemType = createResultType.getElementType();
1134
1135 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1136 // element type is different.
1137 Type newResultElemType = nullptr;
1138 for (Operation *user : createResult.getUsers()) {
1139 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1140 if (writeOp.getArrRef() != createResult) {
1141 continue;
1142 }
1143 Type writeRValueType = writeOp.getRvalue().getType();
1144 if (writeRValueType == oldResultElemType) {
1145 continue;
1146 }
1147 if (newResultElemType && newResultElemType != writeRValueType) {
1148 LLVM_DEBUG(
1149 llvm::dbgs()
1150 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1151 << newResultElemType << " vs " << writeRValueType << '\n'
1152 );
1153 return failure();
1154 }
1155 newResultElemType = writeRValueType;
1156 }
1157 }
1158 if (!newResultElemType) {
1159 // no replacement type found
1160 return failure();
1161 }
1162 if (!tracker_.isLegalConversion(
1163 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1164 )) {
1165 return failure();
1166 }
1167 ArrayType newType = createResultType.cloneWith(newResultElemType);
1168 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1169 LLVM_DEBUG(
1170 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1171 );
1172 return success();
1173 }
1174};
1175
1176namespace {
1177
1178LogicalResult updateArrayElemFromArrAccessOp(
1179 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1180 PatternRewriter &rewriter
1181) {
1182 ArrayType oldArrType = op.getArrRefType();
1183 if (oldArrType.getElementType() == scalarElemTy) {
1184 return failure(); // no change needed
1185 }
1186 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1187 if (oldArrType == newArrType ||
1188 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1189 return failure();
1190 }
1191 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1192 LLVM_DEBUG(
1193 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1194 );
1195 return success();
1196}
1197
1198} // namespace
1199
1200class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1201 ConversionTracker &tracker_;
1202
1203public:
1204 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1205 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1206
1207 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1208 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1209 }
1210};
1211
1212class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1213 ConversionTracker &tracker_;
1214
1215public:
1216 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1217 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1218
1219 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1220 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1221 }
1222};
1223
1225class UpdateFieldDefTypeFromWrite final : public OpRewritePattern<FieldDefOp> {
1226 ConversionTracker &tracker_;
1227
1228public:
1229 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1230 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1231
1232 LogicalResult matchAndRewrite(FieldDefOp op, PatternRewriter &rewriter) const override {
1233 // Find all uses of the field symbol name within its parent struct.
1234 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(op);
1235 assert(succeeded(parentRes) && "FieldDefOp parent is always StructDefOp"); // per ODS def
1236
1237 // If the symbol is used by a FieldWriteOp with a different result type then change
1238 // the type of the FieldDefOp to match the FieldWriteOp result type.
1239 Type newType = nullptr;
1240 if (auto fieldUsers = llzk::getSymbolUses(op, parentRes.value())) {
1241 std::optional<Location> newTypeLoc = std::nullopt;
1242 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1243 if (FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1244 Type writeToType = writeOp.getVal().getType();
1245 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] checking " << writeOp << '\n');
1246 if (!newType) {
1247 // If a new type has not yet been discovered, store the new type.
1248 newType = writeToType;
1249 newTypeLoc = writeOp.getLoc();
1250 } else if (writeToType != newType) {
1251 // Typically, there will only be one write for each field of a struct but do not rely on
1252 // that assumption. If multiple writes with a different types A and B are found where
1253 // A->B is a legal conversion (i.e., more concrete unification), then it is safe to use
1254 // type B with the assumption that the write with type A will be updated by another
1255 // pattern to also use type B.
1256 if (!tracker_.isLegalConversion(writeToType, newType, "UpdateFieldDefTypeFromWrite")) {
1257 if (tracker_.isLegalConversion(newType, writeToType, "UpdateFieldDefTypeFromWrite")) {
1258 // 'writeToType' is the more concrete type
1259 newType = writeToType;
1260 newTypeLoc = writeOp.getLoc();
1261 } else {
1262 // Give an error if the types are incompatible.
1263 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1264 diag.append(
1265 "Cannot update type of '", FieldDefOp::getOperationName(),
1266 "' because there are multiple '", FieldWriteOp::getOperationName(),
1267 "' with different value types"
1268 );
1269 if (newTypeLoc) {
1270 diag.attachNote(*newTypeLoc).append("type written here is ", newType);
1271 }
1272 diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
1273 });
1274 }
1275 }
1276 }
1277 }
1278 }
1279 }
1280 if (!newType || newType == op.getType()) {
1281 // nothing changed
1282 return failure();
1283 }
1284 if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateFieldDefTypeFromWrite")) {
1285 return failure();
1286 }
1287 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
1288 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] updated type of " << op << '\n');
1289 return success();
1290 }
1291};
1292
1293namespace {
1294
1295SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1296 SmallVector<std::unique_ptr<Region>> newRegions;
1297 for (Region &region : op->getRegions()) {
1298 auto newRegion = std::make_unique<Region>();
1299 newRegion->takeBody(region);
1300 newRegions.push_back(std::move(newRegion));
1301 }
1302 return newRegions;
1303}
1304
1305} // namespace
1306
1309class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1310 ConversionTracker &tracker_;
1311
1312public:
1313 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1314 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1315
1316 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1317 SmallVector<Type, 1> inferredResultTypes;
1318 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1319 LogicalResult result = retTypeFn.inferReturnTypes(
1320 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1321 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1322 );
1323 if (failed(result)) {
1324 return failure();
1325 }
1326 if (op->getResultTypes() == inferredResultTypes) {
1327 // nothing changed
1328 return failure();
1329 }
1330 if (!tracker_.areLegalConversions(
1331 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1332 )) {
1333 return failure();
1334 }
1335
1336 // Move nested region bodies and replace the original op with the updated types list.
1337 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1338 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1339 Operation *newOp = rewriter.create(
1340 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1341 op->getAttrs(), op->getSuccessors(), newRegions
1342 );
1343 rewriter.replaceOp(op, newOp);
1344 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1345 return success();
1346 }
1347};
1348
1350class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1351 ConversionTracker &tracker_;
1352
1353public:
1354 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1355 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1356
1357 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1358 Region &body = op.getFunctionBody();
1359 if (body.empty()) {
1360 return failure();
1361 }
1362 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1363 assert(retOp && "final op in body region must be return");
1364 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1365
1366 FunctionType oldFuncTy = op.getFunctionType();
1367 if (oldFuncTy.getResults() == tyFromReturnOp) {
1368 // nothing changed
1369 return failure();
1370 }
1371 if (!tracker_.areLegalConversions(
1372 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1373 )) {
1374 return failure();
1375 }
1376
1377 rewriter.modifyOpInPlace(op, [&]() {
1378 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1379 });
1380 LLVM_DEBUG(
1381 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
1382 << oldFuncTy << " to " << op.getFunctionType() << '\n'
1383 );
1384 return success();
1385 }
1386};
1387
1392class UpdateGlobalCallOpTypes final : public OpRewritePattern<CallOp> {
1393 ConversionTracker &tracker_;
1394
1395public:
1396 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1397 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1398
1399 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1400 SymbolTableCollection tables;
1401 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1402 if (failed(lookupRes)) {
1403 return failure();
1404 }
1405 FuncDefOp targetFunc = lookupRes->get();
1406 if (targetFunc.isInStruct()) {
1407 // this pattern only applies when the callee is NOT in a struct
1408 return failure();
1409 }
1410 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
1411 // nothing changed
1412 return failure();
1413 }
1414 if (!tracker_.areLegalConversions(
1415 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
1416 "UpdateGlobalCallOpTypes"
1417 )) {
1418 return failure();
1419 }
1420
1421 LLVM_DEBUG(llvm::dbgs() << "[UpdateGlobalCallOpTypes] replaced " << op);
1422 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
1423 (void)newOp; // tell compiler it's intentionally unused in release builds
1424 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1425 return success();
1426 }
1427};
1428
1429namespace {
1430
1431LogicalResult updateFieldRefValFromFieldDef(
1432 FieldRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
1433) {
1434 SymbolTableCollection tables;
1435 auto def = op.getFieldDefOp(tables);
1436 if (failed(def)) {
1437 return failure();
1438 }
1439 Type oldResultType = op.getVal().getType();
1440 Type newResultType = def->get().getType();
1441 if (oldResultType == newResultType ||
1442 !tracker.isLegalConversion(oldResultType, newResultType, "updateFieldRefValFromFieldDef")) {
1443 return failure();
1444 }
1445 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
1446 LLVM_DEBUG(
1447 llvm::dbgs() << "[updateFieldRefValFromFieldDef] updated value type in " << op << '\n'
1448 );
1449 return success();
1450}
1451
1452} // namespace
1453
1455class UpdateFieldReadValFromDef final : public OpRewritePattern<FieldReadOp> {
1456 ConversionTracker &tracker_;
1457
1458public:
1459 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1460 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1461
1462 LogicalResult matchAndRewrite(FieldReadOp op, PatternRewriter &rewriter) const override {
1463 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1464 }
1465};
1466
1468class UpdateFieldWriteValFromDef final : public OpRewritePattern<FieldWriteOp> {
1469 ConversionTracker &tracker_;
1470
1471public:
1472 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1473 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1474
1475 LogicalResult matchAndRewrite(FieldWriteOp op, PatternRewriter &rewriter) const override {
1476 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1477 }
1478};
1479
1480LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1481 MLIRContext *ctx = modOp.getContext();
1482 RewritePatternSet patterns(ctx);
1483 patterns.add<
1484 // Benefit of this one must be higher than rules that would propagate the type in the opposite
1485 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
1486 // benefit = 6
1487 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
1488 // benefit = 3
1489 UpdateGlobalCallOpTypes, // CallOp, targeting non-struct functions
1490 UpdateFuncTypeFromReturn, // FuncDefOp
1491 UpdateNewArrayElemFromWrite, // CreateArrayOp
1492 UpdateArrayElemFromArrRead, // ReadArrayOp
1493 UpdateArrayElemFromArrWrite, // WriteArrayOp
1494 UpdateFieldDefTypeFromWrite, // FieldDefOp
1495 UpdateFieldReadValFromDef, // FieldReadOp
1496 UpdateFieldWriteValFromDef // FieldWriteOp
1497 >(ctx, tracker);
1498
1499 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1500}
1501} // namespace Step4_PropagateTypes
1502
1503namespace Step5_Cleanup {
1504
1505class CleanupBase {
1506public:
1507 SymbolTableCollection tables;
1508
1509 CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph)
1510 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1511
1512protected:
1513 ModuleOp rootMod;
1514 const SymbolDefTree &defTree;
1515 const SymbolUseGraph &useGraph;
1516};
1517
1518struct FromKeepSet : public CleanupBase {
1519 using CleanupBase::CleanupBase;
1520
1524 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1525 // Initialize roots from the given StructDefOp instances
1526 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1527 // Add GlobalDefOp and "free functions" to the set of roots
1528 rootMod.walk([&roots](Operation *op) {
1529 if (global::GlobalDefOp gdef = llvm::dyn_cast<global::GlobalDefOp>(op)) {
1530 roots.insert(gdef);
1531 } else if (function::FuncDefOp fdef = llvm::dyn_cast<function::FuncDefOp>(op)) {
1532 if (!fdef.isInStruct()) {
1533 roots.insert(fdef);
1534 }
1535 }
1536 });
1537
1538 // Use a SymbolDefTree to find all Symbol defs reachable from one of the root nodes. Then
1539 // collect all Symbol uses reachable from those def nodes. These are the symbols that should
1540 // be preserved. All other symbol defs should be removed.
1541 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
1542 for (size_t i = 0; i < roots.size(); ++i) { // iterate for safe insertion
1543 SymbolOpInterface keepRoot = roots[i];
1544 LLVM_DEBUG({ llvm::dbgs() << "[EraseUnreachable] root: " << keepRoot << '\n'; });
1545 const SymbolDefTreeNode *keepRootNode = defTree.lookupNode(keepRoot);
1546 assert(keepRootNode && "every struct def must be in the def tree");
1547 for (const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1548 LLVM_DEBUG({
1549 llvm::dbgs() << "[EraseUnreachable] can reach: " << reachableDefNode->getOp() << '\n';
1550 });
1551 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1552 // Use 'depth_first_ext()' to get all symbol uses reachable from the current Symbol def
1553 // node. There are no uses if the node is not in the graph. Within the loop that populates
1554 // 'depth_first_ext()', also check if the symbol is a StructDefOp and ensure it is in
1555 // 'roots' so the outer loop will ensure that all symbols reachable from it are preserved.
1556 if (const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1557 for (const SymbolUseGraphNode *usedSymbolNode :
1558 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1559 LLVM_DEBUG({
1560 llvm::dbgs() << "[EraseUnreachable] uses symbol: "
1561 << usedSymbolNode->getSymbolPath() << '\n';
1562 });
1563 // Ignore struct/template parameter symbols (before doing the lookup below because it
1564 // would fail anyway and then cause the "failed" case to be triggered unnecessarily).
1565 if (usedSymbolNode->isStructParam()) {
1566 continue;
1567 }
1568 // If `usedSymbolNode` references a StructDefOp, ensure it's considered in the roots.
1569 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1570 if (failed(lookupRes)) {
1571 LLVM_DEBUG(useGraph.dumpToDotFile());
1572 return failure();
1573 }
1574 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1575 if (lookupRes->viaInclude()) {
1576 continue;
1577 }
1578 if (StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1579 bool insertRes = roots.insert(asStruct);
1580 (void)insertRes; // tell compiler it's intentionally unused in release builds
1581 LLVM_DEBUG({
1582 if (insertRes) {
1583 llvm::dbgs() << "[EraseUnreachable] found another root: " << asStruct << '\n';
1584 }
1585 });
1586 }
1587 }
1588 }
1589 }
1590 }
1591 }
1592
1593 rootMod.walk([this, &symbolsToKeep](StructDefOp op) {
1594 const SymbolUseGraphNode *n = this->useGraph.lookupNode(op);
1595 assert(n);
1596 if (!symbolsToKeep.contains(n)) {
1597 LLVM_DEBUG(llvm::dbgs() << "[EraseUnreachable] removing: " << op.getSymName() << '\n');
1598 op.erase();
1599 }
1600
1601 return WalkResult::skip(); // StructDefOp cannot be nested
1602 });
1603
1604 return success();
1605 }
1606};
1607
1608struct FromEraseSet : public CleanupBase {
1609
1611 FromEraseSet(
1612 ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph,
1613 DenseSet<SymbolRefAttr> &&tryToErasePaths
1614 )
1615 : CleanupBase(root, symDefTree, symUseGraph) {
1616 // Convert the set of paths targeted for erasure into a set of the StructDefOp
1617 for (SymbolRefAttr path : tryToErasePaths) {
1618 Operation *lookupFrom = rootMod.getOperation();
1619 auto res = lookupSymbolIn<StructDefOp>(tables, path, lookupFrom, lookupFrom);
1620 assert(succeeded(res) && "inputs must be valid StructDefOp references");
1621 if (!res->viaInclude()) { // do not remove if it's from another source file
1622 tryToErase.insert(res->get());
1623 }
1624 }
1625 }
1626
1627 LogicalResult eraseUnusedStructs() {
1628 // Collect the subset of 'tryToErase' that has no remaining uses.
1629 for (StructDefOp sd : tryToErase) {
1630 collectSafeToErase(sd);
1631 }
1632 // The `visitedPlusSafetyResult` will contain FuncDefOp w/in the StructDefOp so just a single
1633 // loop to `dyn_cast` and `erase()` will cause `use-after-free` errors w/in the `dyn_cast`.
1634 // Instead, reduce the map to only those that should be erased and erase in a separate loop.
1635 for (auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1636 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1637 visitedPlusSafetyResult.erase(it);
1638 }
1639 }
1640 for (auto &[sym, _] : visitedPlusSafetyResult) {
1641 LLVM_DEBUG(llvm::dbgs() << "[EraseIfUnused] removing: " << sym.getNameAttr() << '\n');
1642 sym.erase();
1643 }
1644 return success();
1645 }
1646
// Read-only access to the set of structs that were resolved from the constructor's paths.
1647 const DenseSet<StructDefOp> &getTryToEraseSet() const { return tryToErase; }
1648
1649private:
// Structs that are candidates for erasure (populated by the constructor; structs reached via
// an include are excluded there).
1651 DenseSet<StructDefOp> tryToErase;
// Memo of collectSafeToErase(SymbolOpInterface): maps each visited symbol to whether it was
// found safe to erase. eraseUnusedStructs() later prunes it to the structs that get erased.
1655 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
// Cache used by cachedLookup(), keyed by use-graph node.
1657 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1658
1661 bool collectSafeToErase(SymbolOpInterface check) {
1662 assert(check); // pre-condition
1663
1664 // If previously visited, return the safety result.
1665 auto visited = visitedPlusSafetyResult.find(check);
1666 if (visited != visitedPlusSafetyResult.end()) {
1667 return visited->second;
1668 }
1669
1670 // If it's a StructDefOp that is not in `tryToErase` then it cannot be erased.
1671 if (StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
1672 if (!tryToErase.contains(sd)) {
1673 visitedPlusSafetyResult[check] = false;
1674 return false;
1675 }
1676 }
1677
1678 // Otherwise, temporarily mark as safe b/c a node cannot keep itself live (and this prevents
1679 // the recursion from getting stuck in an infinite loop).
1680 visitedPlusSafetyResult[check] = true;
1681
1682 // Check if it's safe according to both the def tree and use graph.
1683 // Note: every symbol must have a def node but module symbols may not have a use node.
1684 if (collectSafeToErase(defTree.lookupNode(check))) {
1685 auto useNode = useGraph.lookupNode(check);
1686 assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
1687 if (!useNode || collectSafeToErase(useNode)) {
1688 return true;
1689 }
1690 }
1691
1692 // Otherwise, revert the safety decision and return it.
1693 visitedPlusSafetyResult[check] = false;
1694 return false;
1695 }
1696
1698 bool collectSafeToErase(const SymbolDefTreeNode *check) {
1699 assert(check); // pre-condition
1700 if (const SymbolDefTreeNode *p = check->getParent()) {
1701 if (SymbolOpInterface checkOp = p->getOp()) { // safe if parent is root
1702 return collectSafeToErase(checkOp);
1703 }
1704 }
1705 return true;
1706 }
1707
1709 bool collectSafeToErase(const SymbolUseGraphNode *check) {
1710 assert(check); // pre-condition
1711 for (const SymbolUseGraphNode *p : check->predecessorIter()) {
1712 if (SymbolOpInterface checkOp = cachedLookup(p)) { // safe if via IncludeOp
1713 if (!collectSafeToErase(checkOp)) {
1714 return false;
1715 }
1716 }
1717 }
1718 return true;
1719 }
1720
1725 SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) {
1726 assert(node && "must provide a node"); // pre-condition
1727 // Check for cached result
1728 auto fromCache = lookupCache.find(node);
1729 if (fromCache != lookupCache.end()) {
1730 return fromCache->second;
1731 }
1732 // Otherwise, perform lookup and cache
1733 auto lookupRes = node->lookupSymbol(tables);
1734 assert(succeeded(lookupRes) && "graph contains node with invalid path");
1735 assert(lookupRes->get() != nullptr && "lookup must return an Operation");
1736 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1737 // NOTE: The SymbolUseGraph does contain nodes for struct parameters which cannot cast to
1738 // SymbolOpInterface. However, those will always be leaf nodes in the SymbolUseGraph and
1739 // therefore will not be traversed by this analysis so directly casting is fine.
1740 SymbolOpInterface actualRes =
1741 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1742 // Cache and return
1743 lookupCache[node] = actualRes;
1744 assert((!actualRes == lookupRes->viaInclude()) && "not found iff included"); // post-condition
1745 return actualRes;
1746 }
1747};
1748
1749} // namespace Step5_Cleanup
1750
1751class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<FlatteningPass> {
1752
1753 void runOnOperation() override {
1754 ModuleOp modOp = getOperation();
1755 if (failed(runOn(modOp))) {
1756 LLVM_DEBUG({
1757 // If the pass failed, dump the current IR.
1758 llvm::dbgs() << "=====================================================================\n";
1759 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << '\n';
1760 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1761 llvm::dbgs() << "=====================================================================\n";
1762 });
1763 signalPassFailure();
1764 }
1765 }
1766
1767 inline LogicalResult runOn(ModuleOp modOp) {
1768 // If the cleanup mode is set to remove anything not reachable from the "Main" struct, do an
1769 // initial pass to remove things that are not reachable (as an optimization) because creating
1770 // an instantiated version of a struct will not cause something to become reachable that was
1771 // not already reachable in parameterized form.
1772 if (cleanupMode == StructCleanupMode::MainAsRoot) {
1773 if (failed(eraseUnreachableFromMainStruct(modOp))) {
1774 return failure();
1775 }
1776 }
1777
1778 {
1779 // Preliminary step: remove empty parameter lists from structs
1780 OpPassManager nestedPM(ModuleOp::getOperationName());
1781 nestedPM.addPass(createEmptyParamListRemoval());
1782 if (failed(runPipeline(nestedPM, modOp))) {
1783 return failure();
1784 }
1785 }
1786
1787 ConversionTracker tracker;
1788 unsigned loopCount = 0;
1789 do {
1790 ++loopCount;
1791 if (loopCount > iterationLimit) {
1792 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit
1793 << " iterations!\n";
1794 return failure();
1795 }
1796 tracker.resetModifiedFlag();
1797
1798 // Find calls to "compute()" that return a parameterized struct and replace it to call a
1799 // flattened version of the struct that has parameters replaced with the constant values.
1800 // Create the necessary instantiated/flattened struct in the same location as the original.
1801 if (failed(Step1_InstantiateStructs::run(modOp, tracker))) {
1802 llvm::errs() << DEBUG_TYPE << " failed while replacing concrete-parameter struct types\n";
1803 return failure();
1804 }
1805
1806 // Unroll loops with known iterations.
1807 if (failed(Step2_Unroll::run(modOp, tracker))) {
1808 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
1809 return failure();
1810 }
1811
1812 // Instantiate affine_map parameters of StructType and ArrayType.
1813 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
1814 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
1815 return failure();
1816 }
1817
1818 // Propagate updated types using the semantics of various ops.
1819 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
1820 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
1821 return failure();
1822 }
1823
1824 LLVM_DEBUG(if (tracker.isModified()) {
1825 llvm::dbgs() << "=====================================================================\n";
1826 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << '\n';
1827 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1828 llvm::dbgs() << "=====================================================================\n";
1829 });
1830 } while (tracker.isModified());
1831
1832 // Perform cleanup according to the 'cleanupMode' option.
1833 switch (cleanupMode) {
1834 case StructCleanupMode::MainAsRoot:
1835 return eraseUnreachableFromMainStruct(modOp, false);
1836 case StructCleanupMode::ConcreteAsRoot:
1837 return eraseUnreachableFromConcreteStructs(modOp);
1838 case StructCleanupMode::Preimage:
1839 return erasePreimageOfInstantiations(modOp, tracker);
1840 case StructCleanupMode::Disabled:
1841 return success();
1842 }
1843 llvm_unreachable("switch cases cover all options");
1844 }
1845
1846 // Erase parameterized structs that were replaced with concrete instantiations.
1847 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod, const ConversionTracker &tracker) {
1848 // TODO: The names from getInstantiatedStructNames() are NOT guaranteed to be paths from the
1849 // "top root" and they also do not indicate a root module so there could be ambiguity. This is a
1850 // broader problem in the FlatteningPass itself so let's just assume, for now, that these are
1851 // paths from the "top root". See [LLZK-286].
1852 Step5_Cleanup::FromEraseSet cleaner(
1853 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
1854 tracker.getInstantiatedStructNames()
1855 );
1856 LogicalResult res = cleaner.eraseUnusedStructs();
1857 if (succeeded(res)) {
1858 // Warn about any structs that were instantiated but still have uses elsewhere.
1859 const SymbolUseGraph *useGraph = nullptr;
1860 rootMod->walk([this, &cleaner, &useGraph](StructDefOp op) {
1861 if (cleaner.getTryToEraseSet().contains(op)) {
1862 // If needed, rebuild use graph to reflect deletions.
1863 if (!useGraph) {
1864 useGraph = &getAnalysis<SymbolUseGraph>();
1865 }
1866 // If the op has any users, report the warning.
1867 if (useGraph->lookupNode(op)->hasPredecessor()) {
1868 op.emitWarning("Parameterized struct still has uses!").report();
1869 }
1870 }
1871 return WalkResult::skip(); // StructDefOp cannot be nested
1872 });
1873 }
1874 return res;
1875 }
1876
1877 LogicalResult eraseUnreachableFromConcreteStructs(ModuleOp rootMod) {
1878 SmallVector<StructDefOp> roots;
1879 rootMod.walk([&roots](StructDefOp op) {
1880 // Note: no need to check if the ConstParamsAttr is empty since `EmptyParamRemovalPass`
1881 // ran earlier.
1882 if (!op.hasConstParamsAttr()) {
1883 roots.push_back(op);
1884 }
1885 return WalkResult::skip(); // StructDefOp cannot be nested
1886 });
1887
1888 Step5_Cleanup::FromKeepSet cleaner(
1889 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1890 );
1891 return cleaner.eraseUnreachableFrom(roots);
1892 }
1893
1894 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod, bool emitWarning = true) {
1895 Step5_Cleanup::FromKeepSet cleaner(
1896 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1897 );
1898 StructDefOp main =
1899 cleaner.tables.getSymbolTable(rootMod).lookup<StructDefOp>(COMPONENT_NAME_MAIN);
1900 if (emitWarning && !main) {
1901 // Emit warning if there is no "Main" because all structs may be removed (only structs that
1902 // are reachable from a global def or free function will be preserved since those constructs
1903 // are not candidate for removal in this pass).
1904 rootMod.emitWarning()
1905 .append(
1906 "using option '", cleanupMode.getArgStr(), '=',
1907 stringifyStructCleanupMode(StructCleanupMode::MainAsRoot), "' with no \"",
1908 COMPONENT_NAME_MAIN, "\" struct may remove all structs!"
1909 )
1910 .report();
1911 }
1912 return cleaner.eraseUnreachableFrom(
1913 main ? ArrayRef<StructDefOp> {main} : ArrayRef<StructDefOp> {}
1914 );
1915 }
1916};
1917
1918} // namespace
1919
1921 return std::make_unique<FlatteningPass>();
1922};
#define DEBUG_TYPE
#define DEBUG_TYPE
Common private implementation for poly dialect passes.
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
const SymbolDefTreeNode * getParent() const
Returns the parent node in the tree. The root node will return nullptr.
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool hasPredecessor() const
Return true if this node has any predecessors.
llvm::iterator_range< iterator > predecessorIter() const
Range over predecessor nodes.
Builds a graph structure representing the relationships between symbols and their uses.
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:923
::mlir::TypedValue<::mlir::Type > getRvalue()
Definition Ops.h.inc:1075
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:502
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:332
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:905
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.h.inc:712
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:599
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:900
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1152
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1600
bool hasConstParamsAttr()
Return false iff getConstParamsAttr() returns nullptr
Definition Ops.h.inc:1279
void setConstParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:1204
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::ArrayAttr getParams() const
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:47
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:761
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:784
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:755
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:272
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Returns the argument types of this function.
Definition Ops.h.inc:749
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:370
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:947
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:784
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:781
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:971
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:896
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:443
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:149
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:240
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:269
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:126
std::unique_ptr< mlir::Pass > createFlatteningPass()
std::unique_ptr< mlir::Pass > createEmptyParamListRemoval()
::llvm::StringRef stringifyStructCleanupMode(StructCleanupMode val)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:225
bool isConcreteType(Type type, bool allowStructParams)
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
Definition Constants.h:23
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:255
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:185
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:259
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
bool isDynamic(IntegerAttr intAttr)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
int64_t fromAPInt(const llvm::APInt &i)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)