LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
1//===-- LLZKFlatteningPass.cpp - Implements -llzk-flatten pass --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
28#include "llzk/Util/Concepts.h"
29#include "llzk/Util/Debug.h"
34
35#include <mlir/Dialect/Affine/IR/AffineOps.h>
36#include <mlir/Dialect/Affine/LoopUtils.h>
37#include <mlir/Dialect/Arith/IR/Arith.h>
38#include <mlir/Dialect/SCF/IR/SCF.h>
39#include <mlir/Dialect/SCF/Utils/Utils.h>
40#include <mlir/Dialect/Utils/StaticValueUtils.h>
41#include <mlir/IR/Attributes.h>
42#include <mlir/IR/BuiltinAttributes.h>
43#include <mlir/IR/BuiltinOps.h>
44#include <mlir/IR/BuiltinTypes.h>
45#include <mlir/Interfaces/InferTypeOpInterface.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Support/LLVM.h>
48#include <mlir/Support/LogicalResult.h>
49#include <mlir/Transforms/DialectConversion.h>
50#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
51
52#include <llvm/ADT/APInt.h>
53#include <llvm/ADT/DenseMap.h>
54#include <llvm/ADT/DepthFirstIterator.h>
55#include <llvm/ADT/STLExtras.h>
56#include <llvm/ADT/SmallVector.h>
57#include <llvm/ADT/TypeSwitch.h>
58#include <llvm/Support/Debug.h>
59
60// Include the generated base pass class definitions.
61namespace llzk::polymorphic {
62#define GEN_PASS_DECL_FLATTENINGPASS
63#define GEN_PASS_DEF_FLATTENINGPASS
65} // namespace llzk::polymorphic
66
67#include "SharedImpl.h"
68
69#define DEBUG_TYPE "llzk-flatten"
70
71using namespace mlir;
72using namespace llzk;
73using namespace llzk::array;
74using namespace llzk::component;
75using namespace llzk::constrain;
76using namespace llzk::felt;
77using namespace llzk::function;
78using namespace llzk::polymorphic;
79using namespace llzk::polymorphic::detail;
80
81namespace {
82
83class ConversionTracker {
85 bool modified;
86 /// Maps original remote (i.e., use site) type to new remote type.
87 /// Note: The keys are always parameterized StructType and the values are no-parameter StructType.
88 DenseMap<StructType, StructType> structInstantiations;
90 DenseMap<StructType, StructType> reverseInstantiations;
93 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
94
95public:
96 bool isModified() const { return modified; }
97 void resetModifiedFlag() { modified = false; }
98 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
99
100 void recordInstantiation(StructType oldType, StructType newType) {
101 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
102
103 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
104 if (forwardResult.second) {
105 // Insertion was successful
106 // ASSERT: The reverse map does not contain this mapping either
107 assert(!reverseInstantiations.contains(newType));
108 reverseInstantiations[newType] = oldType;
109 // Set the modified flag
110 modified = true;
111 } else {
112 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
113 assert(forwardResult.first->getSecond() == newType);
114 // ASSERT: The reverse mapping is already present as well
115 assert(reverseInstantiations.lookup(newType) == oldType);
116 }
117 assert(structInstantiations.size() == reverseInstantiations.size());
119
121 std::optional<StructType> getInstantiation(StructType oldType) const {
122 auto cachedResult = structInstantiations.find(oldType);
123 if (cachedResult != structInstantiations.end()) {
124 return cachedResult->second;
125 }
126 return std::nullopt;
127 }
130 DenseSet<SymbolRefAttr> getInstantiatedStructNames() const {
131 DenseSet<SymbolRefAttr> instantiatedNames;
132 for (const auto &[origRemoteTy, _] : structInstantiations) {
133 instantiatedNames.insert(origRemoteTy.getNameRef());
134 }
135 return instantiatedNames;
136 }
137
138 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
139 auto res = delayedDiagnostics.find(newType);
140 if (res == delayedDiagnostics.end()) {
141 return;
142 }
143
144 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
145 for (Diagnostic &diag : res->second) {
146 // Update any notes referencing an UnknownLoc to use the CallOp location.
147 for (Diagnostic &note : diag.getNotes()) {
148 assert(note.getNotes().empty() && "notes cannot have notes attached");
149 if (llvm::isa<UnknownLoc>(note.getLocation())) {
150 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
151 }
152 }
153 // Report. Based on InFlightDiagnostic::report().
154 engine.emit(std::move(diag));
156 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
157 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
158 // `compute()` calls it will only be reported at one of those locations, not all.
159 delayedDiagnostics.erase(newType);
160 }
161
162 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
163 return delayedDiagnostics[newType];
164 }
165
168 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
169 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
170 // Check if `oTy` is a struct with a known instantiation to `nTy`
171 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
172 // Note: The values in `structInstantiations` must be no-parameter struct types
173 // so there is no need for recursive check, simple equality is sufficient.
174 if (this->structInstantiations.lookup(oldStructType) == nTy) {
175 return true;
176 }
177 }
178 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
179 // that instantiation (i.e., the parameterized version of the instantiated struct)
180 // is a more concrete unification of `oTy`.
181 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
182 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
183 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
184 return true;
185 }
186 }
187 }
188 return false;
189 };
190
191 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
192 return true;
193 }
194 LLVM_DEBUG(llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
195 << " with new type " << newType
196 << " because it does not define a compatible and more concrete type.\n";
197 );
198 return false;
199 }
200
201 template <typename T, typename U>
202 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
203 return llvm::all_of(
204 llvm::zip_equal(oldTypes, newTypes),
205 [this, &patName](std::tuple<Type, Type> oldThenNew) {
206 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
207 }
208 );
209 }
210};
211
214struct MatchFailureListener : public RewriterBase::Listener {
215 bool hadFailure = false;
216
217 ~MatchFailureListener() override {}
218
219 LogicalResult
220 notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
221 hadFailure = true;
222
223 InFlightDiagnostic diag = emitError(loc);
224 reasonCallback(*diag.getUnderlyingDiagnostic());
225 return diag; // implicitly calls `diag.report()`
226 }
227};
228
229static LogicalResult
230applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
231 bool currStepModified = false;
232 MatchFailureListener failureListener;
233 LogicalResult result = applyPatternsAndFoldGreedily(
234 modOp->getRegion(0), std::move(patterns),
235 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener}, &currStepModified
236 );
237 tracker.updateModifiedFlag(currStepModified);
238 return failure(result.failed() || failureListener.hadFailure);
239}
240
241template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
242 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
243 return isConcreteType(tyAttr.getValue(), AllowStructParams);
244 }
245 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
246 return !isDynamic(intAttr);
247 }
248 return false;
249}
250
252
253static inline bool tableOffsetIsntSymbol(FieldReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
255}
256
259class StructCloner {
260 ConversionTracker &tracker_;
261 ModuleOp rootMod;
262 SymbolTableCollection symTables;
263
264 class MappedTypeConverter : public TypeConverter {
265 StructType origTy;
266 StructType newTy;
267 const DenseMap<Attribute, Attribute> &paramNameToValue;
268
269 inline Attribute convertIfPossible(Attribute a) const {
270 auto res = this->paramNameToValue.find(a);
271 return (res != this->paramNameToValue.end()) ? res->second : a;
272 }
273
274 public:
275 MappedTypeConverter(
276 StructType originalType, StructType newType,
278 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
279 )
280 : TypeConverter(), origTy(originalType), newTy(newType),
281 paramNameToValue(paramNameToInstantiatedValue) {
282
283 addConversion([](Type inputTy) { return inputTy; });
284
285 addConversion([this](StructType inputTy) {
286 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
287
288 // Check for replacement of the full type
289 if (inputTy == this->origTy) {
290 return this->newTy;
291 }
292 // Check for replacement of parameter symbol names with concrete values
293 if (ArrayAttr inputTyParams = inputTy.getParams()) {
294 SmallVector<Attribute> updated;
295 for (Attribute a : inputTyParams) {
296 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
297 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
298 } else {
299 updated.push_back(convertIfPossible(a));
300 }
301 }
302 return StructType::get(
303 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
304 );
305 }
306 // Otherwise, return the type unchanged
307 return inputTy;
308 });
309
310 addConversion([this](ArrayType inputTy) {
311 // Check for replacement of parameter symbol names with concrete values
312 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
313 if (!dimSizes.empty()) {
314 SmallVector<Attribute> updated;
315 for (Attribute a : dimSizes) {
316 updated.push_back(convertIfPossible(a));
317 }
318 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
319 }
320 // Otherwise, return the type unchanged
321 return inputTy;
322 });
323
324 addConversion([this](TypeVarType inputTy) -> Type {
325 // Check for replacement of parameter symbol name with a concrete type
326 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
327 Type convertedType = tyAttr.getValue();
328 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
329 // different struct references a parameter name from that other struct, not from the
330 // current struct so the reference would be invalid.
331 if (isConcreteType(convertedType)) {
332 return convertedType;
333 }
334 }
335 return inputTy;
336 });
337 }
338 };
339
340 template <typename Impl, typename Op, typename... HandledAttrs>
341 class SymbolUserHelper : public OpConversionPattern<Op> {
342 private:
343 const DenseMap<Attribute, Attribute> &paramNameToValue;
344
345 SymbolUserHelper(
346 TypeConverter &converter, MLIRContext *ctx, unsigned Benefit,
347 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
348 )
349 : OpConversionPattern<Op>(converter, ctx, Benefit),
350 paramNameToValue(paramNameToInstantiatedValue) {}
351
352 public:
353 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
354
355 virtual Attribute getNameAttr(Op) const = 0;
356
357 virtual LogicalResult handleDefaultRewrite(
358 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
359 ) const {
360 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
361 }
362
363 LogicalResult
364 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
365 auto res = this->paramNameToValue.find(getNameAttr(op));
366 if (res == this->paramNameToValue.end()) {
367 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] no instantiation for " << op << '\n');
368 return failure();
369 }
370 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
371 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
372
373 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
374 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
375 }))),
376 ...);
377
378 return TS.Default([&](Attribute a) {
379 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
380 });
381 }
382 friend Impl;
383 };
384
385 class ClonedStructConstReadOpPattern
386 : public SymbolUserHelper<
387 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
388 SmallVector<Diagnostic> &diagnostics;
389
390 using super =
391 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
392
393 public:
394 ClonedStructConstReadOpPattern(
395 TypeConverter &converter, MLIRContext *ctx,
396 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
397 SmallVector<Diagnostic> &instantiationDiagnostics
398 )
399 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
400 // instead of the GeneralTypeReplacePattern<ConstReadOp> from newGeneralRewritePatternSet().
401 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue),
402 diagnostics(instantiationDiagnostics) {}
403
404 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
405
406 LogicalResult handleRewrite(
407 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
408 ) const {
409 APInt attrValue = a.getValue();
410 Type origResTy = op.getType();
411 if (llvm::isa<FeltType>(origResTy)) {
413 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
414 );
415 return success();
416 }
417
418 if (llvm::isa<IndexType>(origResTy)) {
420 return success();
421 }
422
423 if (origResTy.isSignlessInteger(1)) {
424 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
425 if (attrValue.isZero()) {
426 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
427 return success();
428 }
429 if (!attrValue.isOne()) {
430 Location opLoc = op.getLoc();
431 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
432 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
433 if (getContext()->shouldPrintOpOnDiagnostic()) {
434 diag.attachNote(opLoc) << "see current operation: " << *op;
435 }
436 diag.attachNote(UnknownLoc::get(getContext()))
437 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \""
438 << sym << "\" for this call";
439 diagnostics.push_back(std::move(diag));
440 }
441 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
442 return success();
443 }
444 return op->emitOpError().append("unexpected result type ", origResTy);
445 }
446
447 LogicalResult handleRewrite(
448 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
449 ) const {
450 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
451 return success();
452 }
453 };
454
455 class ClonedStructFieldReadOpPattern
456 : public SymbolUserHelper<
457 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
458 using super =
459 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
460
461 public:
462 ClonedStructFieldReadOpPattern(
463 TypeConverter &converter, MLIRContext *ctx,
464 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
465 )
466 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
467 // instead of the GeneralTypeReplacePattern<FieldReadOp> from newGeneralRewritePatternSet().
468 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue) {}
469
470 Attribute getNameAttr(FieldReadOp op) const override {
471 return op.getTableOffset().value_or(nullptr);
472 }
473
474 template <typename Attr>
475 LogicalResult handleRewrite(
476 Attribute, FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
477 ) const {
478 rewriter.modifyOpInPlace(op, [&]() {
479 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
480 });
481
482 return success();
483 }
484
485 LogicalResult matchAndRewrite(
486 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
487 ) const override {
488 if (tableOffsetIsntSymbol(op)) {
489 return failure();
490 }
491
492 return super::matchAndRewrite(op, adaptor, rewriter);
493 }
494 };
495
496 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
497 // Find the StructDefOp for the original StructType
498 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.getDefinition(symTables, rootMod);
499 if (failed(r)) {
500 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
501 return failure(); // getDefinition() already emits a sufficient error message
502 }
503
504 StructDefOp origStruct = r->get();
505 StructType typeAtDef = origStruct.getType();
506 MLIRContext *ctx = origStruct.getContext();
507
508 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
509 DenseMap<Attribute, Attribute> paramNameToConcrete;
510 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
511 // where the original attribute from the current instantiation site was not concrete. This is
512 // used for generating the new struct name. See `BuildShortTypeString::from()`.
513 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
514 // Parameter list for the new StructDefOp containing the names that must be preserved because
515 // they were not assigned concrete values at the current instantiation site.
516 ArrayAttr reducedParamNameList = nullptr;
517 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
518 ArrayAttr reducedCallerParams = nullptr;
519 {
520 ArrayAttr paramNames = typeAtDef.getParams();
521
522 // pre-conditions
523 assert(!isNullOrEmpty(paramNames));
524 assert(paramNames.size() == typeAtCallerParams.size());
525
526 SmallVector<Attribute> remainingNames;
527 SmallVector<Attribute> nonConcreteParams;
528 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
529 Attribute next = typeAtCallerParams[i];
530 if (isConcreteAttr<false>(next)) {
531 paramNameToConcrete[paramNames[i]] = next;
532 attrsForInstantiatedNameSuffix.push_back(next);
533 } else {
534 remainingNames.push_back(paramNames[i]);
535 nonConcreteParams.push_back(next);
536 attrsForInstantiatedNameSuffix.push_back(nullptr);
537 }
538 }
539 // post-conditions
540 assert(remainingNames.size() == nonConcreteParams.size());
541 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
542 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
543
544 if (paramNameToConcrete.empty()) {
545 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
546 return failure();
547 }
548 if (!remainingNames.empty()) {
549 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
550 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
551 }
552 }
553
554 // Clone the original struct, apply the new name, and set the parameter list of the new struct
555 // to contain only those that did not have concrete instantiated values.
556 StructDefOp newStruct = origStruct.clone();
557 newStruct.setConstParamsAttr(reducedParamNameList);
559 typeAtCaller.getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
560 ));
561
562 // Insert 'newStruct' into the parent ModuleOp of the original StructDefOp. Use the
563 // `SymbolTable::insert()` function directly so that the name will be made unique.
564 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>(); // parent is ModuleOp per ODS
565 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
566 // Retrieve the new type AFTER inserting since the name may be appended to make it unique and
567 // use the remaining non-concrete parameters from the original type.
568 StructType newRemoteType = newStruct.getType(reducedCallerParams);
569 LLVM_DEBUG({
570 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
571 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
572 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
573 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
574 });
575
576 // Within the new struct, replace all references to the original StructType (i.e., the
577 // locally-parameterized version) with the new locally-parameterized StructType,
578 // and replace all uses of the removed struct parameters with the concrete values.
579 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
580 ConversionTarget target =
581 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
582 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
583 // Legal if it's not in the map of concrete attribute instantiations
584 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
585 });
586
587 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
588 patterns.add<ClonedStructConstReadOpPattern>(
589 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
590 );
591 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
592 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
593 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
594 return failure();
595 }
596 return newRemoteType;
597 }
598
599public:
600 StructCloner(ConversionTracker &tracker, ModuleOp root)
601 : tracker_(tracker), rootMod(root), symTables() {}
602
603 FailureOr<StructType> createInstantiatedClone(StructType orig) {
604 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
605 if (ArrayAttr params = orig.getParams()) {
606 return genClone(orig, params.getValue());
607 }
608 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
609 return failure();
610 }
611};
612
613class ParameterizedStructUseTypeConverter : public TypeConverter {
614 ConversionTracker &tracker_;
615 StructCloner cloner;
616
617public:
618 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
619 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
620
621 addConversion([](Type inputTy) { return inputTy; });
622
623 addConversion([this](StructType inputTy) -> StructType {
624 // First check for a cached entry
625 if (auto opt = tracker_.getInstantiation(inputTy)) {
626 return opt.value();
627 }
628
629 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
630 // done, return the original type to indicate that it's still legal (for this step at least).
631 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
632 if (failed(cloneRes)) {
633 return inputTy;
634 }
635 StructType newTy = cloneRes.value();
636 LLVM_DEBUG(
637 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
638 << " as " << newTy << '\n'
639 );
640 tracker_.recordInstantiation(inputTy, newTy);
641 return newTy;
642 });
643
644 addConversion([this](ArrayType inputTy) {
645 return inputTy.cloneWith(convertType(inputTy.getElementType()));
646 });
647 }
648};
649
650class CallStructFuncPattern : public OpConversionPattern<CallOp> {
651 ConversionTracker &tracker_;
652
653public:
654 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
655 // Must use higher benefit than CallOpClassReplacePattern so this pattern will be applied
656 // instead of the CallOpClassReplacePattern from newGeneralRewritePatternSet().
657 : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/2), tracker_(tracker) {}
658
659 LogicalResult matchAndRewrite(CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter)
660 const override {
661 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');
662
663 // Convert the result types of the CallOp
664 SmallVector<Type> newResultTypes;
665 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
666 return op->emitError("Could not convert Op result types.");
667 }
668 LLVM_DEBUG({
669 llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
670 << debug::toStringList(newResultTypes) << '\n';
671 });
672
673 // Update the callee to reflect the new struct target if necessary. These checks are based on
674 // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
675 // Instead they must come from the converted versions.
676 SymbolRefAttr calleeAttr = op.getCalleeAttr();
677 if (op.calleeIsStructCompute()) {
678 if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
679 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
680 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
681 tracker_.reportDelayedDiagnostics(newStTy, op);
682 }
683 } else if (op.calleeIsStructConstrain()) {
684 if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
685 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
686 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
687 }
688 }
689
690 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
692 rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
693 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
694 );
695 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
696 return success();
697 }
698};
699
700// This one ensures FieldDefOp types are converted even if there are no reads/writes to them.
701class FieldDefOpPattern : public OpConversionPattern<FieldDefOp> {
702public:
703 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
704 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
705 // instead of the GeneralTypeReplacePattern<FieldDefOp> from newGeneralRewritePatternSet().
706 : OpConversionPattern<FieldDefOp>(converter, ctx, /*benefit=*/2) {}
707
708 LogicalResult matchAndRewrite(
709 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
710 ) const override {
711 LLVM_DEBUG(llvm::dbgs() << "[FieldDefOpPattern] FieldDefOp: " << op << '\n');
712
713 Type oldFieldType = op.getType();
714 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
715 if (oldFieldType == newFieldType) {
716 // nothing changed
717 return failure();
718 }
719 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.setType(newFieldType); });
720 return success();
721 }
722};
723
724LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
725 MLIRContext *ctx = modOp.getContext();
726 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
727 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
728 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
729 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
730 return applyPartialConversion(modOp, target, std::move(patterns));
731}
732
733} // namespace Step1_InstantiateStructs
734
735namespace Step2_Unroll {
736
737// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
738template <HasInterface<LoopLikeOpInterface> OpClass>
739class LoopUnrollPattern : public OpRewritePattern<OpClass> {
740public:
741 using OpRewritePattern<OpClass>::OpRewritePattern;
742
743 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
744 if (auto maybeConstant = getConstantTripCount(loopOp)) {
745 uint64_t tripCount = *maybeConstant;
746 if (tripCount == 0) {
747 rewriter.eraseOp(loopOp);
748 return success();
749 } else if (tripCount == 1) {
750 return loopOp.promoteIfSingleIteration(rewriter);
751 }
752 return loopUnrollByFactor(loopOp, tripCount);
753 }
754 return failure();
755 }
756
757private:
760 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
761 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
762 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
763 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
764 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
765 return std::nullopt;
766 }
767 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
768 }
769};
770
771LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
772 MLIRContext *ctx = modOp.getContext();
773 RewritePatternSet patterns(ctx);
774 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
775 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
776
777 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
778}
779} // namespace Step2_Unroll
780
782
783// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
784// version uses a basic loop instead of llvm::map_to_vector().
785std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
786 SmallVector<int64_t> res;
787 for (OpFoldResult ofr : ofrs) {
788 std::optional<int64_t> cv = getConstantIntValue(ofr);
789 if (!cv.has_value()) {
790 return std::nullopt;
791 }
792 res.push_back(cv.value());
793 }
794 return res;
795}
796
797struct AffineMapFolder {
798 struct Input {
799 OperandRangeRange mapOpGroups;
800 DenseI32ArrayAttr dimsPerGroup;
801 ArrayRef<Attribute> paramsOfStructTy;
802 };
803
804 struct Output {
805 SmallVector<SmallVector<Value>> mapOpGroups;
806 SmallVector<int32_t> dimsPerGroup;
807 SmallVector<Attribute> paramsOfStructTy;
808 };
809
810 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
811 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
812 return ValueRange(grp);
813 });
814 }
815
816 static LogicalResult
817 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
818 if (in.mapOpGroups.empty()) {
819 // No affine map operands so nothing to do
820 return failure();
821 }
822
823 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
824 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
825
826 size_t idx = 0; // index in `mapOpGroups`, i.e., the number of AffineMapAttr encountered
827 for (Attribute sizeAttr : in.paramsOfStructTy) {
828 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
829 ValueRange currMapOps = in.mapOpGroups[idx++];
830 LLVM_DEBUG(
831 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
832 << '\n'
833 );
834 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
835 LLVM_DEBUG(
836 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
837 << debug::toStringList(currMapOpsCast) << '\n'
838 );
839 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
840 SmallVector<Attribute> result;
841 bool hasPoison = false; // indicates divide by 0 or mod by <1
842 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
843 return rewriter.getIndexAttr(v);
844 });
845 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
846 if (hasPoison) {
847 LLVM_DEBUG(op->emitRemark().append(
848 "Cannot fold affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
849 " due to divide by 0 or modulus with negative divisor"
850 ));
851 return failure();
852 }
853 if (failed(foldResult)) {
854 LLVM_DEBUG(op->emitRemark().append(
855 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " failed"
856 ));
857 return failure();
858 }
859 if (result.size() != 1) {
860 LLVM_DEBUG(op->emitRemark().append(
861 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " produced ",
862 result.size(), " results but expected 1"
863 ));
864 return failure();
865 }
866 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
867 out.paramsOfStructTy.push_back(result[0]);
868 continue;
869 }
870 // If affine but not foldable, preserve the map ops
871 out.mapOpGroups.emplace_back(currMapOps);
872 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
873 }
874 // If not affine and foldable, preserve the original
875 out.paramsOfStructTy.push_back(sizeAttr);
876 }
877 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
878 assert(
879 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
880 "produced wrong number of dimensions"
881 );
882
883 return success();
884 }
885};
886
// Rewrite pattern for CreateArrayOp: runs AffineMapFolder over the dimension sizes of the result
// ArrayType and, if that yields a different ArrayType, replaces the op with one using that type.
888class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
889 ConversionTracker &tracker_;
890
891public:
892 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
893 : OpRewritePattern(ctx), tracker_(tracker) {}
894
// Returns failure when folding fails (including when the op has no affine map operands) or when
// the folded type equals the current result type.
895 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
896 ArrayType oldResultType = op.getType();
897
898 AffineMapFolder::Output out;
899 AffineMapFolder::Input in = {
900 op.getMapOperands(),
// NOTE(review): `AffineMapFolder::Input` has three fields but only two initializers are visible
// here; listing line 901 appears to be missing from this extract — confirm against upstream.
902 oldResultType.getDimensionSizes(),
903 };
904 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
905 return failure();
906 }
907
// Build the candidate result type from the same element type and the folded dimension list.
908 ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
909 if (newResultType == oldResultType) {
910 // nothing changed
911 return failure();
912 }
913 // ASSERT: folding only preserves the original Attribute or converts affine to integer
914 assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
915 LLVM_DEBUG(
916 llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
917 << newResultType << " in \"" << op << "\"\n"
918 );
// NOTE(review): the start of the call that takes the arguments below is not visible; listing
// line 919 appears to be missing from this extract — confirm against upstream.
920 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
921 );
922 return success();
923 }
924};
925
// Rewrite pattern for CallOp whose callee is a struct "compute()" function: computes a refined
// StructType for the call result, either by folding affine_map struct parameters or, when the
// call has no affine map operands, by unifying the call argument types with the argument types
// of the target function.
927class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
928 ConversionTracker &tracker_;
929
930public:
931 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
932 : OpRewritePattern(ctx), tracker_(tracker) {}
933
934 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
935 if (!op.calleeIsStructCompute()) {
936 // this pattern only applies when the callee is "compute()" within a struct
937 return failure();
938 }
939 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
// NOTE(review): `oldRetTy` is used below but its declaration is not visible; listing line 940
// appears to be missing from this extract — confirm against upstream.
941 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
942 ArrayAttr params = oldRetTy.getParams();
943 if (isNullOrEmpty(params)) {
944 // nothing to do if the StructType is not parameterized
945 return failure();
946 }
947
948 AffineMapFolder::Output out;
949 AffineMapFolder::Input in = {
950 op.getMapOperands(),
// NOTE(review): `AffineMapFolder::Input` has three fields but only two initializers are visible
// here; listing line 951 appears to be missing from this extract — confirm against upstream.
952 params.getValue(),
953 };
954 if (!in.mapOpGroups.empty()) {
955 // If there are affine map operands, attempt to fold them to a constant.
956 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
957 return failure();
958 }
959 LLVM_DEBUG({
960 llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
961 });
962 } else {
963 // If there are no affine map operands, attempt to refine the result type of the CallOp using
964 // the function argument types and the type of the target function.
965 auto callArgTypes = op.getArgOperands().getTypes();
966 if (callArgTypes.empty()) {
967 // no refinement possible if no function arguments
968 return failure();
969 }
970 SymbolTableCollection tables;
971 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
972 if (failed(lookupRes)) {
973 return failure();
974 }
975 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
976 return failure();
977 }
978 LLVM_DEBUG({
979 llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
980 "result type params: "
981 << debug::toStringList(out.paramsOfStructTy) << '\n';
982 });
983 }
984
// Either branch above filled `out.paramsOfStructTy`; build the candidate result type from it.
985 StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
986 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
987 if (newRetTy == oldRetTy) {
988 // nothing changed
989 return failure();
990 }
991 // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
992 // original Attribute or convert to a concrete attribute via the unification process. Thus, if
993 // the conversion here is illegal it means there is a type conflict within the LLZK code that
994 // prevents instantiation of the struct with the requested type.
995 if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
996 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
997 diag.append(
998 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
999 ", but found ", oldRetTy
1000 );
1001 });
1002 }
1003 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
// NOTE(review): the statement that declares `newOp` (printed by the debug line below) starts on
// a line that is not visible; listing line 1004 appears to be missing — confirm against upstream.
1005 rewriter, op, TypeRange {newRetTy}, op.getCallee(),
1006 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
1007 );
1008 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1009 return success();
1010 }
1011
1012private:
// Fills `out.paramsOfStructTy` with one attribute per entry of `in.paramsOfStructTy`. Entries
// that are already concrete at the call site are kept; for the others, the matching parameter of
// the target function's result struct is looked up in the unifications of the argument types and
// used when it maps to a concrete attribute. Returns failure when every call-site parameter is
// already concrete. Leaves `out.mapOpGroups` and `out.dimsPerGroup` untouched (asserted empty).
1015 inline LogicalResult instantiateViaTargetType(
1016 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1017 OperandRange::type_range callArgTypes, FuncDefOp targetFunc
1018 ) const {
1019 assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
1020 ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
1021 assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
1022 assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this
1023
1024 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1025 // Nothing can change if everything is already concrete
1026 return failure();
1027 }
1028
1029 LLVM_DEBUG({
1030 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1031 << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
1032 llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
1033 << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
1034 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1035 << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
1036 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1037 << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
1038 });
1039
// Target function argument types are the LHS and the call argument types are the RHS.
1040 UnificationMap unifications;
1041 bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
1042 assert(unifies && "should have been checked by verifiers");
1043
1044 LLVM_DEBUG({
1045 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1046 << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
1047 });
1048
1049 // Check for LHS SymRef (i.e., from the target function) that have RHS concrete Attributes (i.e.
1050 // from the call argument types) without any struct parameters (because the type with concrete
1051 // struct parameters will be used to instantiate the target struct rather than the fully
1052 // flattened struct type resulting in type mismatch of the callee to target) and perform those
1053 // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
1054 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1055 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1056 [&unifications](std::tuple<Attribute, Attribute> p) {
1057 Attribute fromCall = std::get<1>(p);
1058 // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
1059 // non-parameterized concrete unification for the target struct parameter symbol.
1060 if (!isConcreteAttr<>(fromCall)) {
1061 Attribute fromTgt = std::get<0>(p);
1062 LLVM_DEBUG({
1063 llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
1064 llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
1065 });
1066 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1067 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1068 if (it != unifications.end()) {
1069 Attribute unifiedAttr = it->second;
1070 LLVM_DEBUG({
1071 llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
1072 });
1073 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1074 return unifiedAttr;
1075 }
1076 }
1077 }
// Fall back to the call-site attribute when it was already concrete or no usable unification
// was found.
1078 return fromCall;
1079 }
1080 );
1081
1082 out.paramsOfStructTy = newReturnStructParams;
1083 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
1084 assert(out.mapOpGroups.empty() && "post-condition");
1085 assert(out.dimsPerGroup.empty() && "post-condition");
1086 return success();
1087 }
1088};
1089
1090LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1091 MLIRContext *ctx = modOp.getContext();
1092 RewritePatternSet patterns(ctx);
1093 patterns.add<
1094 InstantiateAtCreateArrayOp, // CreateArrayOp
1095 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1096 >(ctx, tracker);
1097
1098 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1099}
1100
1101} // namespace Step3_InstantiateAffineMaps
1102
1104
1106class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1107 ConversionTracker &tracker_;
1108
1109public:
1110 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1111 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1112
1113 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1114 Value createResult = op.getResult();
1115 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1116 assert(createResultType && "CreateArrayOp must produce ArrayType");
1117 Type oldResultElemType = createResultType.getElementType();
1118
1119 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1120 // element type is different.
1121 Type newResultElemType = nullptr;
1122 for (Operation *user : createResult.getUsers()) {
1123 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1124 if (writeOp.getArrRef() != createResult) {
1125 continue;
1126 }
1127 Type writeRValueType = writeOp.getRvalue().getType();
1128 if (writeRValueType == oldResultElemType) {
1129 continue;
1130 }
1131 if (newResultElemType && newResultElemType != writeRValueType) {
1132 LLVM_DEBUG(
1133 llvm::dbgs()
1134 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1135 << newResultElemType << " vs " << writeRValueType << '\n'
1136 );
1137 return failure();
1138 }
1139 newResultElemType = writeRValueType;
1140 }
1141 }
1142 if (!newResultElemType) {
1143 // no replacement type found
1144 return failure();
1145 }
1146 if (!tracker_.isLegalConversion(
1147 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1148 )) {
1149 return failure();
1150 }
1151 ArrayType newType = createResultType.cloneWith(newResultElemType);
1152 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1153 LLVM_DEBUG(
1154 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1155 );
1156 return success();
1157 }
1158};
1159
1160namespace {
1161
1162LogicalResult updateArrayElemFromArrAccessOp(
1163 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1164 PatternRewriter &rewriter
1165) {
1166 ArrayType oldArrType = op.getArrRefType();
1167 if (oldArrType.getElementType() == scalarElemTy) {
1168 return failure(); // no change needed
1169 }
1170 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1171 if (oldArrType == newArrType ||
1172 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1173 return failure();
1174 }
1175 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1176 LLVM_DEBUG(
1177 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1178 );
1179 return success();
1180}
1181
1182} // namespace
1183
1184class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1185 ConversionTracker &tracker_;
1186
1187public:
1188 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1189 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1190
1191 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1192 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1193 }
1194};
1195
1196class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1197 ConversionTracker &tracker_;
1198
1199public:
1200 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1201 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1202
1203 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1204 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1205 }
1206};
1207
1209class UpdateFieldDefTypeFromWrite final : public OpRewritePattern<FieldDefOp> {
1210 ConversionTracker &tracker_;
1211
1212public:
1213 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1214 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1215
1216 LogicalResult matchAndRewrite(FieldDefOp op, PatternRewriter &rewriter) const override {
1217 // Find all uses of the field symbol name within its parent struct.
1218 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(op);
1219 assert(succeeded(parentRes) && "FieldDefOp parent is always StructDefOp"); // per ODS def
1220
1221 // If the symbol is used by a FieldWriteOp with a different result type then change
1222 // the type of the FieldDefOp to match the FieldWriteOp result type.
1223 Type newType = nullptr;
1224 if (auto fieldUsers = llzk::getSymbolUses(op, parentRes.value())) {
1225 std::optional<Location> newTypeLoc = std::nullopt;
1226 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1227 if (FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1228 Type writeToType = writeOp.getVal().getType();
1229 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] checking " << writeOp << '\n');
1230 if (!newType) {
1231 // If a new type has not yet been discovered, store the new type.
1232 newType = writeToType;
1233 newTypeLoc = writeOp.getLoc();
1234 } else if (writeToType != newType) {
1235 // Typically, there will only be one write for each field of a struct but do not rely on
1236 // that assumption. If multiple writes with a different types A and B are found where
1237 // A->B is a legal conversion (i.e., more concrete unification), then it is safe to use
1238 // type B with the assumption that the write with type A will be updated by another
1239 // pattern to also use type B.
1240 if (!tracker_.isLegalConversion(writeToType, newType, "UpdateFieldDefTypeFromWrite")) {
1241 if (tracker_.isLegalConversion(newType, writeToType, "UpdateFieldDefTypeFromWrite")) {
1242 // 'writeToType' is the more concrete type
1243 newType = writeToType;
1244 newTypeLoc = writeOp.getLoc();
1245 } else {
1246 // Give an error if the types are incompatible.
1247 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1248 diag.append(
1249 "Cannot update type of '", FieldDefOp::getOperationName(),
1250 "' because there are multiple '", FieldWriteOp::getOperationName(),
1251 "' with different value types"
1252 );
1253 if (newTypeLoc) {
1254 diag.attachNote(*newTypeLoc).append("type written here is ", newType);
1255 }
1256 diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
1257 });
1258 }
1259 }
1260 }
1261 }
1262 }
1263 }
1264 if (!newType || newType == op.getType()) {
1265 // nothing changed
1266 return failure();
1267 }
1268 if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateFieldDefTypeFromWrite")) {
1269 return failure();
1270 }
1271 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
1272 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] updated type of " << op << '\n');
1273 return success();
1274 }
1275};
1276
1277namespace {
1278
1279SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1280 SmallVector<std::unique_ptr<Region>> newRegions;
1281 for (Region &region : op->getRegions()) {
1282 auto newRegion = std::make_unique<Region>();
1283 newRegion->takeBody(region);
1284 newRegions.push_back(std::move(newRegion));
1285 }
1286 return newRegions;
1287}
1288
1289} // namespace
1290
1293class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1294 ConversionTracker &tracker_;
1295
1296public:
1297 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1298 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1299
1300 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1301 SmallVector<Type, 1> inferredResultTypes;
1302 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1303 LogicalResult result = retTypeFn.inferReturnTypes(
1304 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1305 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1306 );
1307 if (failed(result)) {
1308 return failure();
1309 }
1310 if (op->getResultTypes() == inferredResultTypes) {
1311 // nothing changed
1312 return failure();
1313 }
1314 if (!tracker_.areLegalConversions(
1315 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1316 )) {
1317 return failure();
1318 }
1319
1320 // Move nested region bodies and replace the original op with the updated types list.
1321 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1322 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1323 Operation *newOp = rewriter.create(
1324 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1325 op->getAttrs(), op->getSuccessors(), newRegions
1326 );
1327 rewriter.replaceOp(op, newOp);
1328 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1329 return success();
1330 }
1331};
1332
1334class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1335 ConversionTracker &tracker_;
1336
1337public:
1338 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1339 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1340
1341 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1342 Region &body = op.getFunctionBody();
1343 if (body.empty()) {
1344 return failure();
1345 }
1346 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1347 assert(retOp && "final op in body region must be return");
1348 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1349
1350 FunctionType oldFuncTy = op.getFunctionType();
1351 if (oldFuncTy.getResults() == tyFromReturnOp) {
1352 // nothing changed
1353 return failure();
1354 }
1355 if (!tracker_.areLegalConversions(
1356 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1357 )) {
1358 return failure();
1359 }
1360
1361 rewriter.modifyOpInPlace(op, [&]() {
1362 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1363 });
1364 LLVM_DEBUG(
1365 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
1366 << oldFuncTy << " to " << op.getFunctionType() << '\n'
1367 );
1368 return success();
1369 }
1370};
1371
1376class UpdateGlobalCallOpTypes final : public OpRewritePattern<CallOp> {
1377 ConversionTracker &tracker_;
1378
1379public:
1380 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1381 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1382
1383 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1384 SymbolTableCollection tables;
1385 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1386 if (failed(lookupRes)) {
1387 return failure();
1388 }
1389 FuncDefOp targetFunc = lookupRes->get();
1390 if (targetFunc.isInStruct()) {
1391 // this pattern only applies when the callee is NOT in a struct
1392 return failure();
1393 }
1394 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
1395 // nothing changed
1396 return failure();
1397 }
1398 if (!tracker_.areLegalConversions(
1399 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
1400 "UpdateGlobalCallOpTypes"
1401 )) {
1402 return failure();
1403 }
1404
1405 LLVM_DEBUG(llvm::dbgs() << "[UpdateGlobalCallOpTypes] replaced " << op);
1406 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
1407 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1408 return success();
1409 }
1410};
1411
1412namespace {
1413
1414LogicalResult updateFieldRefValFromFieldDef(
1415 FieldRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
1416) {
1417 SymbolTableCollection tables;
1418 auto def = op.getFieldDefOp(tables);
1419 if (failed(def)) {
1420 return failure();
1421 }
1422 Type oldResultType = op.getVal().getType();
1423 Type newResultType = def->get().getType();
1424 if (oldResultType == newResultType ||
1425 !tracker.isLegalConversion(oldResultType, newResultType, "updateFieldRefValFromFieldDef")) {
1426 return failure();
1427 }
1428 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
1429 LLVM_DEBUG(
1430 llvm::dbgs() << "[updateFieldRefValFromFieldDef] updated value type in " << op << '\n'
1431 );
1432 return success();
1433}
1434
1435} // namespace
1436
1438class UpdateFieldReadValFromDef final : public OpRewritePattern<FieldReadOp> {
1439 ConversionTracker &tracker_;
1440
1441public:
1442 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1443 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1444
1445 LogicalResult matchAndRewrite(FieldReadOp op, PatternRewriter &rewriter) const override {
1446 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1447 }
1448};
1449
1451class UpdateFieldWriteValFromDef final : public OpRewritePattern<FieldWriteOp> {
1452 ConversionTracker &tracker_;
1453
1454public:
1455 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1456 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1457
1458 LogicalResult matchAndRewrite(FieldWriteOp op, PatternRewriter &rewriter) const override {
1459 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1460 }
1461};
1462
1463LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1464 MLIRContext *ctx = modOp.getContext();
1465 RewritePatternSet patterns(ctx);
1466 patterns.add<
1467 // Benefit of this one must be higher than rules that would propagate the type in the opposite
1468 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
1469 // benefit = 6
1470 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
1471 // benefit = 3
1472 UpdateGlobalCallOpTypes, // CallOp, targeting non-struct functions
1473 UpdateFuncTypeFromReturn, // FuncDefOp
1474 UpdateNewArrayElemFromWrite, // CreateArrayOp
1475 UpdateArrayElemFromArrRead, // ReadArrayOp
1476 UpdateArrayElemFromArrWrite, // WriteArrayOp
1477 UpdateFieldDefTypeFromWrite, // FieldDefOp
1478 UpdateFieldReadValFromDef, // FieldReadOp
1479 UpdateFieldWriteValFromDef // FieldWriteOp
1480 >(ctx, tracker);
1481
1482 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1483}
1484} // namespace Step4_PropagateTypes
1485
1486namespace Step5_Cleanup {
1487
1488class CleanupBase {
1489public:
1490 SymbolTableCollection tables;
1491
1492 CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph)
1493 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1494
1495protected:
1496 ModuleOp rootMod;
1497 const SymbolDefTree &defTree;
1498 const SymbolUseGraph &useGraph;
1499};
1500
1501struct FromKeepSet : public CleanupBase {
1502 using CleanupBase::CleanupBase;
1503
1507 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1508 // Initialize roots from the given StructDefOp instances
1509 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1510 // Add GlobalDefOp and "free functions" to the set of roots
1511 rootMod.walk([&roots](Operation *op) {
1512 if (global::GlobalDefOp gdef = llvm::dyn_cast<global::GlobalDefOp>(op)) {
1513 roots.insert(gdef);
1514 } else if (function::FuncDefOp fdef = llvm::dyn_cast<function::FuncDefOp>(op)) {
1515 if (!fdef.isInStruct()) {
1516 roots.insert(fdef);
1517 }
1518 }
1519 });
1520
1521 // Use a SymbolDefTree to find all Symbol defs reachable from one of the root nodes. Then
1522 // collect all Symbol uses reachable from those def nodes. These are the symbols that should
1523 // be preserved. All other symbol defs should be removed.
1524 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
1525 for (size_t i = 0; i < roots.size(); ++i) { // iterate for safe insertion
1526 SymbolOpInterface keepRoot = roots[i];
1527 LLVM_DEBUG({ llvm::dbgs() << "[EraseUnreachable] root: " << keepRoot << '\n'; });
1528 const SymbolDefTreeNode *keepRootNode = defTree.lookupNode(keepRoot);
1529 assert(keepRootNode && "every struct def must be in the def tree");
1530 for (const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1531 LLVM_DEBUG({
1532 llvm::dbgs() << "[EraseUnreachable] can reach: " << reachableDefNode->getOp() << '\n';
1533 });
1534 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1535 // Use 'depth_first_ext()' to get all symbol uses reachable from the current Symbol def
1536 // node. There are no uses if the node is not in the graph. Within the loop that populates
1537 // 'depth_first_ext()', also check if the symbol is a StructDefOp and ensure it is in
1538 // 'roots' so the outer loop will ensure that all symbols reachable from it are preserved.
1539 if (const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1540 for (const SymbolUseGraphNode *usedSymbolNode :
1541 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1542 LLVM_DEBUG({
1543 llvm::dbgs() << "[EraseUnreachable] uses symbol: "
1544 << usedSymbolNode->getSymbolPath() << '\n';
1545 });
1546 // Ignore struct/template parameter symbols (before doing the lookup below because it
1547 // would fail anyway and then cause the "failed" case to be triggered unnecessarily).
1548 if (usedSymbolNode->isStructParam()) {
1549 continue;
1550 }
1551 // If `usedSymbolNode` references a StructDefOp, ensure it's considered in the roots.
1552 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1553 if (failed(lookupRes)) {
1554 LLVM_DEBUG(useGraph.dumpToDotFile());
1555 return failure();
1556 }
1557 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1558 if (lookupRes->viaInclude()) {
1559 continue;
1560 }
1561 if (StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1562 bool insertRes = roots.insert(asStruct);
1563 LLVM_DEBUG({
1564 if (insertRes) {
1565 llvm::dbgs() << "[EraseUnreachable] found another root: " << asStruct << '\n';
1566 }
1567 });
1568 }
1569 }
1570 }
1571 }
1572 }
1573 }
1574
1575 rootMod.walk([this, &symbolsToKeep](StructDefOp op) {
1576 const SymbolUseGraphNode *n = this->useGraph.lookupNode(op);
1577 assert(n);
1578 if (!symbolsToKeep.contains(n)) {
1579 LLVM_DEBUG(llvm::dbgs() << "[EraseUnreachable] removing: " << op.getSymName() << '\n');
1580 op.erase();
1581 }
1582
1583 return WalkResult::skip(); // StructDefOp cannot be nested
1584 });
1585
1586 return success();
1587 }
1588};
1589
1590struct FromEraseSet : public CleanupBase {
1591
1593 FromEraseSet(
1594 ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph,
1595 DenseSet<SymbolRefAttr> &&tryToErasePaths
1596 )
1597 : CleanupBase(root, symDefTree, symUseGraph) {
1598 // Convert the set of paths targeted for erasure into a set of the StructDefOp
1599 for (SymbolRefAttr path : tryToErasePaths) {
1600 Operation *lookupFrom = rootMod.getOperation();
1601 auto res = lookupSymbolIn<StructDefOp>(tables, path, lookupFrom, lookupFrom);
1602 assert(succeeded(res) && "inputs must be valid StructDefOp references");
1603 if (!res->viaInclude()) { // do not remove if it's from another source file
1604 tryToErase.insert(res->get());
1605 }
1606 }
1607 }
1608
1609 LogicalResult eraseUnusedStructs() {
1610 // Collect the subset of 'tryToErase' that has no remaining uses.
1611 for (StructDefOp sd : tryToErase) {
1612 collectSafeToErase(sd);
1613 }
1614 // The `visitedPlusSafetyResult` will contain FuncDefOp w/in the StructDefOp so just a single
1615 // loop to `dyn_cast` and `erase()` will cause `use-after-free` errors w/in the `dyn_cast`.
1616 // Instead, reduce the map to only those that should be erased and erase in a separate loop.
1617 for (auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1618 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1619 visitedPlusSafetyResult.erase(it);
1620 }
1621 }
1622 for (auto &[sym, _] : visitedPlusSafetyResult) {
1623 LLVM_DEBUG(llvm::dbgs() << "[EraseIfUnused] removing: " << sym.getNameAttr() << '\n');
1624 sym.erase();
1625 }
1626 return success();
1627 }
1628
// Read-only access to the set of structs that the constructor resolved from the erase paths.
1629 const DenseSet<StructDefOp> &getTryToEraseSet() const { return tryToErase; }
1630
1631private:
// Structs targeted for erasure; filled by the constructor.
1633 DenseSet<StructDefOp> tryToErase;
// Memo of `collectSafeToErase(SymbolOpInterface)`: maps each visited symbol op to whether it is
// currently considered safe to erase.
1637 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
// Cache used by `cachedLookup()` to avoid repeating the symbol lookup for a use graph node.
1639 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1640
1643 bool collectSafeToErase(SymbolOpInterface check) {
1644 assert(check); // pre-condition
1645
1646 // If previously visited, return the safety result.
1647 auto visited = visitedPlusSafetyResult.find(check);
1648 if (visited != visitedPlusSafetyResult.end()) {
1649 return visited->second;
1650 }
1651
1652 // If it's a StructDefOp that is not in `tryToErase` then it cannot be erased.
1653 if (StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
1654 if (!tryToErase.contains(sd)) {
1655 visitedPlusSafetyResult[check] = false;
1656 return false;
1657 }
1658 }
1659
1660 // Otherwise, temporarily mark as safe b/c a node cannot keep itself live (and this prevents
1661 // the recursion from getting stuck in an infinite loop).
1662 visitedPlusSafetyResult[check] = true;
1663
1664 // Check if it's safe according to both the def tree and use graph.
1665 // Note: every symbol must have a def node but module symbols may not have a use node.
1666 if (collectSafeToErase(defTree.lookupNode(check))) {
1667 auto useNode = useGraph.lookupNode(check);
1668 assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
1669 if (!useNode || collectSafeToErase(useNode)) {
1670 return true;
1671 }
1672 }
1673
1674 // Otherwise, revert the safety decision and return it.
1675 visitedPlusSafetyResult[check] = false;
1676 return false;
1677 }
1678
1680 bool collectSafeToErase(const SymbolDefTreeNode *check) {
1681 assert(check); // pre-condition
1682 if (const SymbolDefTreeNode *p = check->getParent()) {
1683 if (SymbolOpInterface checkOp = p->getOp()) { // safe if parent is root
1684 return collectSafeToErase(checkOp);
1685 }
1686 }
1687 return true;
1688 }
1689
1691 bool collectSafeToErase(const SymbolUseGraphNode *check) {
1692 assert(check); // pre-condition
1693 for (const SymbolUseGraphNode *p : check->predecessorIter()) {
1694 if (SymbolOpInterface checkOp = cachedLookup(p)) { // safe if via IncludeOp
1695 if (!collectSafeToErase(checkOp)) {
1696 return false;
1697 }
1698 }
1699 }
1700 return true;
1701 }
1702
1707 SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) {
1708 assert(node && "must provide a node"); // pre-condition
1709 // Check for cached result
1710 auto fromCache = lookupCache.find(node);
1711 if (fromCache != lookupCache.end()) {
1712 return fromCache->second;
1713 }
1714 // Otherwise, perform lookup and cache
1715 auto lookupRes = node->lookupSymbol(tables);
1716 assert(succeeded(lookupRes) && "graph contains node with invalid path");
1717 assert(lookupRes->get() != nullptr && "lookup must return an Operation");
1718 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1719 // NOTE: The SymbolUseGraph does contain nodes for struct parameters which cannot cast to
1720 // SymbolOpInterface. However, those will always be leaf nodes in the SymbolUseGraph and
1721 // therefore will not be traversed by this analysis so directly casting is fine.
1722 SymbolOpInterface actualRes =
1723 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1724 // Cache and return
1725 lookupCache[node] = actualRes;
1726 assert((!actualRes == lookupRes->viaInclude()) && "not found iff included"); // post-condition
1727 return actualRes;
1728 }
1729};
1730
1731} // namespace Step5_Cleanup
1732
1733class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<FlatteningPass> {
1734
1735 void runOnOperation() override {
1736 ModuleOp modOp = getOperation();
1737 if (failed(runOn(modOp))) {
1738 LLVM_DEBUG({
1739 // If the pass failed, dump the current IR.
1740 llvm::dbgs() << "=====================================================================\n";
1741 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << '\n';
1742 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1743 llvm::dbgs() << "=====================================================================\n";
1744 });
1745 signalPassFailure();
1746 }
1747 }
1748
1749 inline LogicalResult runOn(ModuleOp modOp) {
1750 // If the cleanup mode is set to remove anything not reachable from the "Main" struct, do an
1751 // initial pass to remove things that are not reachable (as an optimization) because creating
1752 // an instantiated version of a struct will not cause something to become reachable that was
1753 // not already reachable in parameterized form.
1754 if (cleanupMode == StructCleanupMode::MainAsRoot) {
1755 if (failed(eraseUnreachableFromMainStruct(modOp))) {
1756 return failure();
1757 }
1758 }
1759
1760 {
1761 // Preliminary step: remove empty parameter lists from structs
1762 OpPassManager nestedPM(ModuleOp::getOperationName());
1763 nestedPM.addPass(createEmptyParamListRemoval());
1764 if (failed(runPipeline(nestedPM, modOp))) {
1765 return failure();
1766 }
1767 }
1768
1769 ConversionTracker tracker;
1770 unsigned loopCount = 0;
1771 do {
1772 ++loopCount;
1773 if (loopCount > iterationLimit) {
1774 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit
1775 << " iterations!\n";
1776 return failure();
1777 }
1778 tracker.resetModifiedFlag();
1779
1780 // Find calls to "compute()" that return a parameterized struct and replace it to call a
1781 // flattened version of the struct that has parameters replaced with the constant values.
1782 // Create the necessary instantiated/flattened struct in the same location as the original.
1783 if (failed(Step1_InstantiateStructs::run(modOp, tracker))) {
1784 llvm::errs() << DEBUG_TYPE << " failed while replacing concrete-parameter struct types\n";
1785 return failure();
1786 }
1787
1788 // Unroll loops with known iterations.
1789 if (failed(Step2_Unroll::run(modOp, tracker))) {
1790 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
1791 return failure();
1792 }
1793
1794 // Instantiate affine_map parameters of StructType and ArrayType.
1795 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
1796 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
1797 return failure();
1798 }
1799
1800 // Propagate updated types using the semantics of various ops.
1801 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
1802 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
1803 return failure();
1804 }
1805
1806 LLVM_DEBUG(if (tracker.isModified()) {
1807 llvm::dbgs() << "=====================================================================\n";
1808 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << '\n';
1809 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1810 llvm::dbgs() << "=====================================================================\n";
1811 });
1812 } while (tracker.isModified());
1813
1814 // Perform cleanup according to the 'cleanupMode' option.
1815 switch (cleanupMode) {
1816 case StructCleanupMode::MainAsRoot:
1817 return eraseUnreachableFromMainStruct(modOp, false);
1818 case StructCleanupMode::ConcreteAsRoot:
1819 return eraseUnreachableFromConcreteStructs(modOp);
1820 case StructCleanupMode::Preimage:
1821 return erasePreimageOfInstantiations(modOp, tracker);
1822 case StructCleanupMode::Disabled:
1823 return success();
1824 }
1825 llvm_unreachable("switch cases cover all options");
1826 }
1827
1828 // Erase parameterized structs that were replaced with concrete instantiations.
1829 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod, const ConversionTracker &tracker) {
1830 // TODO: The names from getInstantiatedStructNames() are NOT guaranteed to be paths from the
1831 // "top root" and they also do not indicate a root module so there could be ambiguity. This is a
1832 // broader problem in the FlatteningPass itself so let's just assume, for now, that these are
1833 // paths from the "top root". See [LLZK-286].
1834 Step5_Cleanup::FromEraseSet cleaner(
1835 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
1836 tracker.getInstantiatedStructNames()
1837 );
1838 LogicalResult res = cleaner.eraseUnusedStructs();
1839 if (succeeded(res)) {
1840 // Warn about any structs that were instantiated but still have uses elsewhere.
1841 const SymbolUseGraph *useGraph = nullptr;
1842 rootMod->walk([this, &cleaner, &useGraph](StructDefOp op) {
1843 if (cleaner.getTryToEraseSet().contains(op)) {
1844 // If needed, rebuild use graph to reflect deletions.
1845 if (!useGraph) {
1846 useGraph = &getAnalysis<SymbolUseGraph>();
1847 }
1848 // If the op has any users, report the warning.
1849 if (useGraph->lookupNode(op)->hasPredecessor()) {
1850 op.emitWarning("Parameterized struct still has uses!").report();
1851 }
1852 }
1853 return WalkResult::skip(); // StructDefOp cannot be nested
1854 });
1855 }
1856 return res;
1857 }
1858
1859 LogicalResult eraseUnreachableFromConcreteStructs(ModuleOp rootMod) {
1860 SmallVector<StructDefOp> roots;
1861 rootMod.walk([&roots](StructDefOp op) {
1862 // Note: no need to check if the ConstParamsAttr is empty since `EmptyParamRemovalPass`
1863 // ran earlier.
1864 if (!op.hasConstParamsAttr()) {
1865 roots.push_back(op);
1866 }
1867 return WalkResult::skip(); // StructDefOp cannot be nested
1868 });
1869
1870 Step5_Cleanup::FromKeepSet cleaner(
1871 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1872 );
1873 return cleaner.eraseUnreachableFrom(roots);
1874 }
1875
1876 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod, bool emitWarning = true) {
1877 Step5_Cleanup::FromKeepSet cleaner(
1878 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1879 );
1880 StructDefOp main =
1881 cleaner.tables.getSymbolTable(rootMod).lookup<StructDefOp>(COMPONENT_NAME_MAIN);
1882 if (emitWarning && !main) {
1883 // Emit warning if there is no "Main" because all structs may be removed (only structs that
1884 // are reachable from a global def or free function will be preserved since those constructs
1885 // are not candidate for removal in this pass).
1886 rootMod.emitWarning() << "using option '" << cleanupMode.getArgStr() << '='
1887 << stringifyStructCleanupMode(StructCleanupMode::MainAsRoot)
1888 << "' with no \"" << COMPONENT_NAME_MAIN
1889 << "\" struct may remove all structs!";
1890 }
1891 return cleaner.eraseUnreachableFrom(
1892 main ? ArrayRef<StructDefOp> {main} : ArrayRef<StructDefOp> {}
1893 );
1894 }
1895};
1896
1897} // namespace
1898
1900 return std::make_unique<FlatteningPass>();
1901};
#define DEBUG_TYPE
#define DEBUG_TYPE
Common private implementation for poly dialect passes.
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
const SymbolDefTreeNode * getParent() const
Returns the parent node in the tree. The root node will return nullptr.
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool hasPredecessor() const
Return true if this node has any predecessors.
llvm::iterator_range< iterator > predecessorIter() const
Range over predecessor nodes.
Builds a graph structure representing the relationships between symbols and their uses.
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:649
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:412
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.cpp.inc:438
::mlir::Value getResult()
Definition Ops.cpp.inc:1467
::mlir::Value getRvalue()
Definition Ops.cpp.inc:1772
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:617
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:292
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:1112
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.cpp.inc:1143
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:509
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:729
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
Definition Ops.cpp:143
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:919
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1971
bool hasConstParamsAttr()
Return false iff getConstParamsAttr() returns nullptr
Definition Ops.h.inc:998
void setConstParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.cpp.inc:1975
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:69
::mlir::ArrayAttr getParams() const
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:694
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:700
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:688
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:531
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.cpp.inc:522
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Returns the argument types of this function.
Definition Ops.h.inc:591
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:309
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1086
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:622
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:619
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1130
::mlir::Operation::operand_range getOperands()
Definition Ops.cpp.inc:1322
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.cpp.inc:662
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:121
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:239
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:268
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:126
std::unique_ptr< mlir::Pass > createFlatteningPass()
std::unique_ptr< mlir::Pass > createEmptyParamListRemoval()
::llvm::StringRef stringifyStructCleanupMode(StructCleanupMode val)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:221
bool isConcreteType(Type type, bool allowStructParams)
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
Definition Constants.h:23
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:251
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:181
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:255
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
bool isDynamic(IntegerAttr intAttr)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
int64_t fromAPInt(llvm::APInt i)