LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
1//===-- LLZKFlatteningPass.cpp - Implements -llzk-flatten pass --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
28#include "llzk/Util/Concepts.h"
29#include "llzk/Util/Debug.h"
34
35#include <mlir/Dialect/Affine/IR/AffineOps.h>
36#include <mlir/Dialect/Affine/LoopUtils.h>
37#include <mlir/Dialect/Arith/IR/Arith.h>
38#include <mlir/Dialect/SCF/IR/SCF.h>
39#include <mlir/Dialect/SCF/Utils/Utils.h>
40#include <mlir/Dialect/Utils/StaticValueUtils.h>
41#include <mlir/IR/Attributes.h>
42#include <mlir/IR/BuiltinAttributes.h>
43#include <mlir/IR/BuiltinOps.h>
44#include <mlir/IR/BuiltinTypes.h>
45#include <mlir/Interfaces/InferTypeOpInterface.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Support/LLVM.h>
48#include <mlir/Support/LogicalResult.h>
49#include <mlir/Transforms/DialectConversion.h>
50#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
51
52#include <llvm/ADT/APInt.h>
53#include <llvm/ADT/DenseMap.h>
54#include <llvm/ADT/DepthFirstIterator.h>
55#include <llvm/ADT/STLExtras.h>
56#include <llvm/ADT/SmallVector.h>
57#include <llvm/ADT/TypeSwitch.h>
58#include <llvm/Support/Debug.h>
59
60// Include the generated base pass class definitions.
61namespace llzk::polymorphic {
62#define GEN_PASS_DECL_FLATTENINGPASS
63#define GEN_PASS_DEF_FLATTENINGPASS
65} // namespace llzk::polymorphic
66
67#include "SharedImpl.h"
68
69#define DEBUG_TYPE "llzk-flatten"
70
71using namespace mlir;
72using namespace llzk;
73using namespace llzk::array;
74using namespace llzk::component;
75using namespace llzk::constrain;
76using namespace llzk::felt;
77using namespace llzk::function;
78using namespace llzk::polymorphic;
79using namespace llzk::polymorphic::detail;
80
81namespace {
82
83class ConversionTracker {
85 bool modified;
88 DenseMap<StructType, StructType> structInstantiations;
90 DenseMap<StructType, StructType> reverseInstantiations;
93 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
94
95public:
96 bool isModified() const { return modified; }
97 void resetModifiedFlag() { modified = false; }
98 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
99
100 void recordInstantiation(StructType oldType, StructType newType) {
101 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
103 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
104 if (forwardResult.second) {
105 // Insertion was successful
106 // ASSERT: The reverse map does not contain this mapping either
107 assert(!reverseInstantiations.contains(newType));
108 reverseInstantiations[newType] = oldType;
109 // Set the modified flag
110 modified = true;
111 } else {
112 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
113 assert(forwardResult.first->getSecond() == newType);
114 // ASSERT: The reverse mapping is already present as well
115 assert(reverseInstantiations.lookup(newType) == oldType);
116 }
117 assert(structInstantiations.size() == reverseInstantiations.size());
118 }
119
120 /// Return the instantiated type of the given StructType, if any.
121 std::optional<StructType> getInstantiation(StructType oldType) const {
122 auto cachedResult = structInstantiations.find(oldType);
123 if (cachedResult != structInstantiations.end()) {
124 return cachedResult->second;
125 }
126 return std::nullopt;
127 }
128
129
130 DenseSet<SymbolRefAttr> getInstantiatedStructNames() const {
131 DenseSet<SymbolRefAttr> instantiatedNames;
132 for (const auto &[origRemoteTy, _] : structInstantiations) {
133 instantiatedNames.insert(origRemoteTy.getNameRef());
134 }
135 return instantiatedNames;
137
138 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
139 auto res = delayedDiagnostics.find(newType);
140 if (res == delayedDiagnostics.end()) {
141 return;
142 }
143
144 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
145 for (Diagnostic &diag : res->second) {
146 // Update any notes referencing an UnknownLoc to use the CallOp location.
147 for (Diagnostic &note : diag.getNotes()) {
148 assert(note.getNotes().empty() && "notes cannot have notes attached");
149 if (llvm::isa<UnknownLoc>(note.getLocation())) {
150 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
151 }
152 }
153 // Report. Based on InFlightDiagnostic::report().
154 engine.emit(std::move(diag));
155 }
156 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
157 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
158 // `compute()` calls it will only be reported at one of those locations, not all.
159 delayedDiagnostics.erase(newType);
160 }
161
162 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
163 return delayedDiagnostics[newType];
165
168 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
169 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
170 // Check if `oTy` is a struct with a known instantiation to `nTy`
171 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
172 // Note: The values in `structInstantiations` must be no-parameter struct types
173 // so there is no need for recursive check, simple equality is sufficient.
174 if (this->structInstantiations.lookup(oldStructType) == nTy) {
175 return true;
176 }
177 }
178 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
179 // that instantiation (i.e., the parameterized version of the instantiated struct)
180 // is a more concrete unification of `oTy`.
181 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
182 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
183 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
184 return true;
185 }
186 }
187 }
188 return false;
189 };
190
191 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
192 return true;
193 }
194 LLVM_DEBUG(
195 llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
196 << " with new type " << newType
197 << " because it does not define a compatible and more concrete type.\n";
198 );
199 return false;
200 }
201
202 template <typename T, typename U>
203 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
204 return llvm::all_of(
205 llvm::zip_equal(oldTypes, newTypes), [this, &patName](std::tuple<Type, Type> oldThenNew) {
206 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
207 }
208 );
209 }
210};
211
214struct MatchFailureListener : public RewriterBase::Listener {
215 bool hadFailure = false;
216
217 ~MatchFailureListener() override {}
218
219 void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
220 hadFailure = true;
221
222 InFlightDiagnostic diag = emitError(loc);
223 reasonCallback(*diag.getUnderlyingDiagnostic());
224 diag.report();
225 }
226};
227
228static LogicalResult
229applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
230 bool currStepModified = false;
231 MatchFailureListener failureListener;
232 LogicalResult result = applyPatternsGreedily(
233 modOp->getRegion(0), std::move(patterns),
234 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
235 &currStepModified
236 );
237 tracker.updateModifiedFlag(currStepModified);
238 return failure(result.failed() || failureListener.hadFailure);
239}
240
241template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
242 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
243 return isConcreteType(tyAttr.getValue(), AllowStructParams);
244 }
245 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
246 return !isDynamic(intAttr);
247 }
248 return false;
249}
250
252
253static inline bool tableOffsetIsntSymbol(FieldReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
255}
256
/// Creates instantiated clones of parameterized StructDefOps. Given a use-site StructType,
/// `createInstantiatedClone()` clones the referenced struct definition and removes the parameters
/// that received concrete values at the use site. It then renames the clone and rewrites the
/// clone's body to use those concrete values. Diagnostics raised while rewriting are stored in
/// the tracker under the clone's use-site type.
259class StructCloner {
260 ConversionTracker &tracker_;
261 ModuleOp rootMod;
262 SymbolTableCollection symTables;
263
 /// Type converter used inside a cloned struct. It replaces `origTy` with `newTy`, and it
 /// substitutes parameter names found in `paramNameToValue` within StructType params, ArrayType
 /// dimension sizes, and TypeVarType names.
264 class MappedTypeConverter : public TypeConverter {
265 StructType origTy;
266 StructType newTy;
267 const DenseMap<Attribute, Attribute> &paramNameToValue;
268
 // Returns the mapped value for `a` if it is a known parameter name, otherwise `a` unchanged.
269 inline Attribute convertIfPossible(Attribute a) const {
270 auto res = this->paramNameToValue.find(a);
271 return (res != this->paramNameToValue.end()) ? res->second : a;
272 }
273
274 public:
275 MappedTypeConverter(
276 StructType originalType, StructType newType,
278 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
279 )
280 : TypeConverter(), origTy(originalType), newTy(newType),
281 paramNameToValue(paramNameToInstantiatedValue) {
282
 // Fallback: any other type is left unchanged.
283 addConversion([](Type inputTy) { return inputTy; });
284
285 addConversion([this](StructType inputTy) {
286 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
287
288 // Check for replacement of the full type
289 if (inputTy == this->origTy) {
290 return this->newTy;
291 }
292 // Check for replacement of parameter symbol names with concrete values
293 if (ArrayAttr inputTyParams = inputTy.getParams()) {
294 SmallVector<Attribute> updated;
295 for (Attribute a : inputTyParams) {
296 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
297 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
298 } else {
299 updated.push_back(convertIfPossible(a));
300 }
301 }
302 return StructType::get(
303 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
304 );
305 }
306 // Otherwise, return the type unchanged
307 return inputTy;
308 });
309
310 addConversion([this](ArrayType inputTy) {
311 // Check for replacement of parameter symbol names with concrete values
312 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
313 if (!dimSizes.empty()) {
314 SmallVector<Attribute> updated;
315 for (Attribute a : dimSizes) {
316 updated.push_back(convertIfPossible(a));
317 }
318 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
319 }
320 // Otherwise, return the type unchanged
321 return inputTy;
322 });
323
324 addConversion([this](TypeVarType inputTy) -> Type {
325 // Check for replacement of parameter symbol name with a concrete type
326 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
327 Type convertedType = tyAttr.getValue();
328 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
329 // different struct references a parameter name from that other struct, not from the
330 // current struct so the reference would be invalid.
331 if (isConcreteType(convertedType)) {
332 return convertedType;
333 }
334 }
335 return inputTy;
336 });
337 }
338 };
339
 /// CRTP base for patterns that rewrite an op referencing a struct parameter by name.
 /// The name returned by `getNameAttr()` is looked up in `paramNameToValue`; ops with no mapping
 /// are not matched. Otherwise the mapped value is dispatched by attribute kind. Each kind in
 /// `HandledAttrs` goes to `Impl::handleRewrite()`, and any other kind goes to
 /// `handleDefaultRewrite()`, which emits an op error unless overridden.
340 template <typename Impl, typename Op, typename... HandledAttrs>
341 class SymbolUserHelper : public OpConversionPattern<Op> {
342 private:
343 const DenseMap<Attribute, Attribute> &paramNameToValue;
344
345 SymbolUserHelper(
346 TypeConverter &converter, MLIRContext *ctx, unsigned Benefit,
347 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
348 )
349 : OpConversionPattern<Op>(converter, ctx, Benefit),
350 paramNameToValue(paramNameToInstantiatedValue) {}
351
352 public:
353 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
354
355 virtual Attribute getNameAttr(Op) const = 0;
356
357 virtual LogicalResult handleDefaultRewrite(
358 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
359 ) const {
360 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
361 }
362
363 LogicalResult
364 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
365 auto res = this->paramNameToValue.find(getNameAttr(op));
366 if (res == this->paramNameToValue.end()) {
367 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] no instantiation for " << op << '\n');
368 return failure();
369 }
370 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
371 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
372
 // Fold expression: chain one `.Case<A>()` onto the TypeSwitch for each A in HandledAttrs.
373 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
374 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
375 }))),
376 ...);
377
378 return TS.Default([&](Attribute a) {
379 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
380 });
381 }
382 friend Impl;
383 };
384
 /// Replaces a ConstReadOp of an instantiated parameter with a constant.
 /// An IntegerAttr value becomes a constant that depends on the op's result type: felt, index, or
 /// i1. For i1, a value other than 0 or 1 also records a delayed warning. A FeltConstAttr value
 /// becomes a FeltConstantOp.
385 class ClonedStructConstReadOpPattern
386 : public SymbolUserHelper<
387 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
 // Delayed diagnostics for the clone's use-site type (owned by the ConversionTracker).
388 SmallVector<Diagnostic> &diagnostics;
389
390 using super =
391 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
392
393 public:
394 ClonedStructConstReadOpPattern(
395 TypeConverter &converter, MLIRContext *ctx,
396 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
397 SmallVector<Diagnostic> &instantiationDiagnostics
398 )
399 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
400 // instead of the GeneralTypeReplacePattern<ConstReadOp> from newGeneralRewritePatternSet().
401 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue),
402 diagnostics(instantiationDiagnostics) {}
403
404 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
405
406 LogicalResult handleRewrite(
407 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
408 ) const {
409 APInt attrValue = a.getValue();
410 Type origResTy = op.getType();
411 if (llvm::isa<FeltType>(origResTy)) {
 // NOTE(review): the line opening the call that these arguments belong to (presumably a
 // replace-op helper creating a felt constant) is missing from this view; confirm.
413 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
414 );
415 return success();
416 }
417
418 if (llvm::isa<IndexType>(origResTy)) {
 // NOTE(review): the statement replacing `op` in the index case is missing from this view;
 // as shown, this branch returns success without rewriting. Confirm against the full file.
420 return success();
421 }
422
423 if (origResTy.isSignlessInteger(1)) {
424 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
425 if (attrValue.isZero()) {
426 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
427 return success();
428 }
429 if (!attrValue.isOne()) {
 // The warning is not emitted here. It is stored and later reported at the call site;
 // the UnknownLoc note is relocated to that call when it is reported.
430 Location opLoc = op.getLoc();
431 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
432 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
433 if (getContext()->shouldPrintOpOnDiagnostic()) {
434 diag.attachNote(opLoc) << "see current operation: " << *op;
435 }
436 diag.attachNote(UnknownLoc::get(getContext()))
437 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \""
438 << sym << "\" for this call";
439 diagnostics.push_back(std::move(diag));
440 }
441 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
442 return success();
443 }
444 return op->emitOpError().append("unexpected result type ", origResTy);
445 }
446
447 LogicalResult handleRewrite(
448 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
449 ) const {
450 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
451 return success();
452 }
453 };
454
 /// Rewrites a FieldReadOp whose table offset is a symbol mapped to an IntegerAttr or
 /// FeltConstAttr value, replacing the offset in place with the equivalent index attribute.
 /// Ops whose table offset is absent or not a SymbolRefAttr are not matched.
455 class ClonedStructFieldReadOpPattern
456 : public SymbolUserHelper<
457 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
458 using super =
459 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
460
461 public:
462 ClonedStructFieldReadOpPattern(
463 TypeConverter &converter, MLIRContext *ctx,
464 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
465 )
466 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
467 // instead of the GeneralTypeReplacePattern<FieldReadOp> from newGeneralRewritePatternSet().
468 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue) {}
469
470 Attribute getNameAttr(FieldReadOp op) const override {
471 return op.getTableOffset().value_or(nullptr);
472 }
473
474 template <typename Attr>
475 LogicalResult handleRewrite(
476 Attribute, FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
477 ) const {
478 rewriter.modifyOpInPlace(op, [&]() {
479 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
480 });
481
482 return success();
483 }
484
485 LogicalResult matchAndRewrite(
486 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
487 ) const override {
488 if (tableOffsetIsntSymbol(op)) {
489 return failure();
490 }
491
492 return super::matchAndRewrite(op, adaptor, rewriter);
493 }
494 };
495
 /// Clones the StructDefOp referenced by `typeAtCaller`, instantiating every parameter whose
 /// entry in `typeAtCallerParams` is concrete (per `isConcreteAttr<false>`).
 /// Returns the use-site type of the clone, which keeps any remaining non-concrete parameters.
 /// Fails if the definition cannot be found, if no parameter is concrete, or if converting the
 /// clone's body fails.
496 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
497 // Find the StructDefOp for the original StructType
498 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.getDefinition(symTables, rootMod);
499 if (failed(r)) {
500 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
501 return failure(); // getDefinition() already emits a sufficient error message
502 }
503
504 StructDefOp origStruct = r->get();
505 StructType typeAtDef = origStruct.getType();
506 MLIRContext *ctx = origStruct.getContext();
507
508 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
509 DenseMap<Attribute, Attribute> paramNameToConcrete;
510 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
511 // where the original attribute from the current instantiation site was not concrete. This is
512 // used for generating the new struct name. See `BuildShortTypeString::from()`.
513 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
514 // Parameter list for the new StructDefOp containing the names that must be preserved because
515 // they were not assigned concrete values at the current instantiation site.
516 ArrayAttr reducedParamNameList = nullptr;
517 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
518 ArrayAttr reducedCallerParams = nullptr;
519 {
520 ArrayAttr paramNames = typeAtDef.getParams();
521
522 // pre-conditions
523 assert(!isNullOrEmpty(paramNames));
524 assert(paramNames.size() == typeAtCallerParams.size());
525
526 SmallVector<Attribute> remainingNames;
527 SmallVector<Attribute> nonConcreteParams;
528 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
529 Attribute next = typeAtCallerParams[i];
530 if (isConcreteAttr<false>(next)) {
531 paramNameToConcrete[paramNames[i]] = next;
532 attrsForInstantiatedNameSuffix.push_back(next);
533 } else {
534 remainingNames.push_back(paramNames[i]);
535 nonConcreteParams.push_back(next);
536 attrsForInstantiatedNameSuffix.push_back(nullptr);
537 }
538 }
539 // post-conditions
540 assert(remainingNames.size() == nonConcreteParams.size());
541 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
542 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
543
544 if (paramNameToConcrete.empty()) {
545 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
546 return failure();
547 }
 // If every parameter was concrete, both lists stay null, so the clone has no parameters.
548 if (!remainingNames.empty()) {
549 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
550 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
551 }
552 }
553
554 // Clone the original struct, apply the new name, and set the parameter list of the new struct
555 // to contain only those that did not have concrete instantiated values.
556 StructDefOp newStruct = origStruct.clone();
557 newStruct.setConstParamsAttr(reducedParamNameList);
 // NOTE(review): the line between `setSymName(` and the arguments below (presumably a
 // name-building call head, cf. `BuildShortTypeString::from()`) is missing from this view.
558 newStruct.setSymName(
560 typeAtCaller.getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
561 )
562 );
563
564 // Insert 'newStruct' into the parent ModuleOp of the original StructDefOp. Use the
565 // `SymbolTable::insert()` function directly so that the name will be made unique.
566 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>(); // parent is ModuleOp per ODS
567 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
568 // Retrieve the new type AFTER inserting since the name may be appended to make it unique and
569 // use the remaining non-concrete parameters from the original type.
570 StructType newRemoteType = newStruct.getType(reducedCallerParams);
571 LLVM_DEBUG({
572 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
573 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
574 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
575 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
576 });
577
578 // Within the new struct, replace all references to the original StructType (i.e., the
579 // locally-parameterized version) with the new locally-parameterized StructType,
580 // and replace all uses of the removed struct parameters with the concrete values.
581 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
582 ConversionTarget target =
583 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
584 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
585 // Legal if it's not in the map of concrete attribute instantiations
586 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
587 });
588
589 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
590 patterns.add<ClonedStructConstReadOpPattern>(
591 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
592 );
593 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
594 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
 // NOTE(review): on this path `newStruct` was already inserted into `parentModule` and is
 // not erased here; confirm that leaving the partially converted clone is intended.
595 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
596 return failure();
597 }
598 return newRemoteType;
599 }
600
601public:
602 StructCloner(ConversionTracker &tracker, ModuleOp root)
603 : tracker_(tracker), rootMod(root), symTables() {}
604
 /// Returns the use-site type of an instantiated clone of the struct referenced by `orig`.
 /// Fails if `orig` has no parameters or if `genClone()` fails.
605 FailureOr<StructType> createInstantiatedClone(StructType orig) {
606 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
607 if (ArrayAttr params = orig.getParams()) {
608 return genClone(orig, params.getValue());
609 }
610 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
611 return failure();
612 }
613};
614
615class ParameterizedStructUseTypeConverter : public TypeConverter {
616 ConversionTracker &tracker_;
617 StructCloner cloner;
618
619public:
620 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
621 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
622
623 addConversion([](Type inputTy) { return inputTy; });
624
625 addConversion([this](StructType inputTy) -> StructType {
626 // First check for a cached entry
627 if (auto opt = tracker_.getInstantiation(inputTy)) {
628 return opt.value();
629 }
630
631 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
632 // done, return the original type to indicate that it's still legal (for this step at least).
633 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
634 if (failed(cloneRes)) {
635 return inputTy;
636 }
637 StructType newTy = cloneRes.value();
638 LLVM_DEBUG(
639 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
640 << " as " << newTy << '\n'
641 );
642 tracker_.recordInstantiation(inputTy, newTy);
643 return newTy;
644 });
645
646 addConversion([this](ArrayType inputTy) {
647 return inputTy.cloneWith(convertType(inputTy.getElementType()));
648 });
649 }
650};
651
/// Rewrites a CallOp so that its result types are converted by the type converter. If the callee
/// is a struct's compute (or constrain) function, the callee symbol is re-targeted to the struct
/// named by the converted result type (or by the converted type of the first argument).
/// For compute calls, delayed instantiation diagnostics for the new struct type are also
/// reported at this call.
652class CallStructFuncPattern : public OpConversionPattern<CallOp> {
653 ConversionTracker &tracker_;
654
655public:
656 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
657 // Must use higher benefit than CallOpClassReplacePattern so this pattern will be applied
658 // instead of the CallOpClassReplacePattern from newGeneralRewritePatternSet().
659 : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/2), tracker_(tracker) {}
660
 /// Emits an error and fails if the result types cannot be converted; otherwise replaces `op`.
661 LogicalResult matchAndRewrite(
662 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
663 ) const override {
664 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');
665
666 // Convert the result types of the CallOp
667 SmallVector<Type> newResultTypes;
668 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
669 return op->emitError("Could not convert Op result types.");
670 }
671 LLVM_DEBUG({
672 llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
673 << debug::toStringList(newResultTypes) << '\n';
674 });
675
676 // Update the callee to reflect the new struct target if necessary. These checks are based on
677 // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
678 // Instead they must come from the converted versions.
679 SymbolRefAttr calleeAttr = op.getCalleeAttr();
680 if (op.calleeIsStructCompute()) {
 // compute(): the struct type is the single converted result type.
681 if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
682 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
683 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
684 tracker_.reportDelayedDiagnostics(newStTy, op);
685 }
686 } else if (op.calleeIsStructConstrain()) {
 // constrain(): the struct type is the (already converted) type of the first argument.
687 if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
688 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
689 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
690 }
691 }
692
693 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
 // NOTE(review): the line that opens the call taking the arguments below and defines `newOp`
 // (presumably replacing `op` with a new CallOp) is missing from this view; confirm.
695 rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
696 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
697 );
698 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
699 return success();
700 }
701};
702
703// This one ensures FieldDefOp types are converted even if there are no reads/writes to them.
704class FieldDefOpPattern : public OpConversionPattern<FieldDefOp> {
705public:
706 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
707 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
708 // instead of the GeneralTypeReplacePattern<FieldDefOp> from newGeneralRewritePatternSet().
709 : OpConversionPattern<FieldDefOp>(converter, ctx, /*benefit=*/2) {}
710
711 LogicalResult matchAndRewrite(
712 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
713 ) const override {
714 LLVM_DEBUG(llvm::dbgs() << "[FieldDefOpPattern] FieldDefOp: " << op << '\n');
715
716 Type oldFieldType = op.getType();
717 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
718 if (oldFieldType == newFieldType) {
719 // nothing changed
720 return failure();
721 }
722 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.setType(newFieldType); });
723 return success();
724 }
725};
726
727LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
728 MLIRContext *ctx = modOp.getContext();
729 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
730 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
731 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
732 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
733 return applyPartialConversion(modOp, target, std::move(patterns));
734}
735
736} // namespace Step1_InstantiateStructs
737
738namespace Step2_Unroll {
739
740// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
741template <HasInterface<LoopLikeOpInterface> OpClass>
742class LoopUnrollPattern : public OpRewritePattern<OpClass> {
743public:
744 using OpRewritePattern<OpClass>::OpRewritePattern;
745
746 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
747 if (auto maybeConstant = getConstantTripCount(loopOp)) {
748 uint64_t tripCount = *maybeConstant;
749 if (tripCount == 0) {
750 rewriter.eraseOp(loopOp);
751 return success();
752 } else if (tripCount == 1) {
753 return loopOp.promoteIfSingleIteration(rewriter);
754 }
755 return loopUnrollByFactor(loopOp, tripCount);
756 }
757 return failure();
758 }
759
760private:
763 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
764 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
765 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
766 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
767 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
768 return std::nullopt;
769 }
770 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
771 }
772};
773
774LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
775 MLIRContext *ctx = modOp.getContext();
776 RewritePatternSet patterns(ctx);
777 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
778 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
779
780 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
781}
782} // namespace Step2_Unroll
783
785
786// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
787// version uses a basic loop instead of llvm::map_to_vector().
788std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
789 SmallVector<int64_t> res;
790 for (OpFoldResult ofr : ofrs) {
791 std::optional<int64_t> cv = getConstantIntValue(ofr);
792 if (!cv.has_value()) {
793 return std::nullopt;
794 }
795 res.push_back(cv.value());
796 }
797 return res;
798}
799
800struct AffineMapFolder {
801 struct Input {
802 OperandRangeRange mapOpGroups;
803 DenseI32ArrayAttr dimsPerGroup;
804 ArrayRef<Attribute> paramsOfStructTy;
805 };
806
807 struct Output {
808 SmallVector<SmallVector<Value>> mapOpGroups;
809 SmallVector<int32_t> dimsPerGroup;
810 SmallVector<Attribute> paramsOfStructTy;
811 };
812
813 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
814 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
815 return ValueRange(grp);
816 });
817 }
818
819 static LogicalResult
820 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
821 if (in.mapOpGroups.empty()) {
822 // No affine map operands so nothing to do
823 return failure();
824 }
825
826 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
827 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
828
829 size_t idx = 0; // index in `mapOpGroups`, i.e., the number of AffineMapAttr encountered
830 for (Attribute sizeAttr : in.paramsOfStructTy) {
831 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
832 ValueRange currMapOps = in.mapOpGroups[idx++];
833 LLVM_DEBUG(
834 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
835 << '\n'
836 );
837 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
838 LLVM_DEBUG(
839 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
840 << debug::toStringList(currMapOpsCast) << '\n'
841 );
842 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
843 SmallVector<Attribute> result;
844 bool hasPoison = false; // indicates divide by 0 or mod by <1
845 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
846 return rewriter.getIndexAttr(v);
847 });
848 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
849 if (hasPoison) {
850 LLVM_DEBUG(op->emitRemark()
851 .append(
852 "Cannot fold affine_map for ", aspect, " ",
853 out.paramsOfStructTy.size(),
854 " due to divide by 0 or modulus with negative divisor"
855 )
856 .report());
857 return failure();
858 }
859 if (failed(foldResult)) {
860 LLVM_DEBUG(op->emitRemark()
861 .append(
862 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
863 " failed"
864 )
865 .report());
866 return failure();
867 }
868 if (result.size() != 1) {
869 LLVM_DEBUG(op->emitRemark()
870 .append(
871 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
872 " produced ", result.size(), " results but expected 1"
873 )
874 .report());
875 return failure();
876 }
877 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
878 out.paramsOfStructTy.push_back(result[0]);
879 continue;
880 }
881 // If affine but not foldable, preserve the map ops
882 out.mapOpGroups.emplace_back(currMapOps);
883 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
884 }
885 // If not affine and foldable, preserve the original
886 out.paramsOfStructTy.push_back(sizeAttr);
887 }
888 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
889 assert(
890 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
891 "produced wrong number of dimensions"
892 );
893
894 return success();
895 }
896};
897
/// Rewrite pattern on CreateArrayOp: uses AffineMapFolder to fold `affine_map` dimension sizes
/// of the result ArrayType whose map operands are all constant, and replaces the op when the
/// resulting ArrayType differs from the original one.
class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
  ConversionTracker &tracker_;

public:
  InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
      : OpRewritePattern(ctx), tracker_(tracker) {}

  LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
    ArrayType oldResultType = op.getType();

    // Fold the dimension sizes of the result type using the op's affine map operands.
    AffineMapFolder::Output out;
    AffineMapFolder::Input in = {
        op.getMapOperands(),
        oldResultType.getDimensionSizes(),
    };
    // Fails when there are no map operands or when a map with all-constant operands cannot fold.
    if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
      return failure();
    }

    ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
    if (newResultType == oldResultType) {
      // nothing changed
      return failure();
    }
    // ASSERT: folding only preserves the original Attribute or converts affine to integer
    assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
    // Rebuild with `newResultType`, passing along the operand groups (and per-group dims entries)
    // of any affine_map sizes that were not folded.
    LLVM_DEBUG(
        llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
                     << newResultType << " in \"" << op << "\"\n"
    );
        rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
    );
    return success();
  }
};
936
/// Rewrite pattern on CallOp targeting a struct's "compute()" function that refines the struct
/// parameters of the call's result type. When the call has affine map operands, maps whose
/// operands are all constant are folded via AffineMapFolder. Otherwise, concrete types of the
/// call arguments are propagated into the result type through unification with the target
/// function's argument types (see `instantiateViaTargetType`).
class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
  ConversionTracker &tracker_;

public:
  InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
      : OpRewritePattern(ctx), tracker_(tracker) {}

  LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
    if (!op.calleeIsStructCompute()) {
      // this pattern only applies when the callee is "compute()" within a struct
      return failure();
    }
    LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
    LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
    ArrayAttr params = oldRetTy.getParams();
    if (isNullOrEmpty(params)) {
      // nothing to do if the StructType is not parameterized
      return failure();
    }

    AffineMapFolder::Output out;
    AffineMapFolder::Input in = {
        op.getMapOperands(),
        params.getValue(),
    };
    if (!in.mapOpGroups.empty()) {
      // If there are affine map operands, attempt to fold them to a constant.
      if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
        return failure();
      }
      LLVM_DEBUG({
        llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
      });
    } else {
      // If there are no affine map operands, attempt to refine the result type of the CallOp using
      // the function argument types and the type of the target function.
      auto callArgTypes = op.getArgOperands().getTypes();
      if (callArgTypes.empty()) {
        // no refinement possible if no function arguments
        return failure();
      }
      SymbolTableCollection tables;
      auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
      if (failed(lookupRes)) {
        return failure();
      }
      if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
        return failure();
      }
      LLVM_DEBUG({
        llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
                        "result type params: "
                     << debug::toStringList(out.paramsOfStructTy) << '\n';
      });
    }

    StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
    LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
    if (newRetTy == oldRetTy) {
      // nothing changed
      return failure();
    }
    // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
    // original Attribute or convert to a concrete attribute via the unification process. Thus, if
    // the conversion here is illegal it means there is a type conflict within the LLZK code that
    // prevents instantiation of the struct with the requested type.
    if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag.append(
            "result type mismatch: due to struct instantiation, expected type ", newRetTy,
            ", but found ", oldRetTy
        );
      });
    }
    // Rebuild the call with the refined result type, the unfolded map operand groups (with their
    // per-group dims entries), and the original arguments.
    LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
        rewriter, op, TypeRange {newRetTy}, op.getCallee(),
        AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
    );
    LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
    return success();
  }

private:
  /// Compute `out.paramsOfStructTy` for a call that has no affine map operands.
  ///
  /// Unifies `targetFunc`'s argument types (LHS) with `callArgTypes` (RHS). Then, for each
  /// parameter of the call's result type that is not already concrete, the corresponding
  /// SymbolRefAttr parameter of `targetFunc`'s result type is looked up in the unifications. If it
  /// unified with an Attribute accepted by `isConcreteAttr<false>`, that Attribute replaces it.
  /// All other parameters are kept as they are at the call site.
  ///
  /// @return failure (leaving `out` untouched) if every parameter at the call site is already
  ///         concrete; success otherwise. `out.mapOpGroups` and `out.dimsPerGroup` stay empty.
  inline LogicalResult instantiateViaTargetType(
      const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
      OperandRange::type_range callArgTypes, FuncDefOp targetFunc
  ) const {
    assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
    ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
    assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
    assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this

    if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
      // Nothing can change if everything is already concrete
      return failure();
    }

    LLVM_DEBUG({
      llvm::dbgs() << '[' << __FUNCTION__ << ']'
                   << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
      llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
                   << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
      llvm::dbgs() << '[' << __FUNCTION__ << ']'
                   << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
      llvm::dbgs() << '[' << __FUNCTION__ << ']'
                   << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
    });

    // Target function arg types are on the LHS, call arg types on the RHS.
    UnificationMap unifications;
    bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
    assert(unifies && "should have been checked by verifiers");

    LLVM_DEBUG({
      llvm::dbgs() << '[' << __FUNCTION__ << ']'
                   << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
    });

    // Check for LHS SymRef (i.e., from the target function) that have RHS concrete Attributes (i.e.
    // from the call argument types) without any struct parameters (because the type with concrete
    // struct parameters will be used to instantiate the target struct rather than the fully
    // flattened struct type resulting in type mismatch of the callee to target) and perform those
    // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
    SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
        llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
        [&unifications](std::tuple<Attribute, Attribute> p) {
      Attribute fromCall = std::get<1>(p);
      // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
      // non-parameterized concrete unification for the target struct parameter symbol.
      if (!isConcreteAttr<>(fromCall)) {
        Attribute fromTgt = std::get<0>(p);
        LLVM_DEBUG({
          llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
          llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
        });
        assert(llvm::isa<SymbolRefAttr>(fromTgt));
        auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
        if (it != unifications.end()) {
          Attribute unifiedAttr = it->second;
          LLVM_DEBUG({
            llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
          });
          if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
            return unifiedAttr;
          }
        }
      }
      return fromCall;
    }
    );

    out.paramsOfStructTy = newReturnStructParams;
    assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
    assert(out.mapOpGroups.empty() && "post-condition");
    assert(out.dimsPerGroup.empty() && "post-condition");
    return success();
  }
};
1100
1101LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1102 MLIRContext *ctx = modOp.getContext();
1103 RewritePatternSet patterns(ctx);
1104 patterns.add<
1105 InstantiateAtCreateArrayOp, // CreateArrayOp
1106 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1107 >(ctx, tracker);
1108
1109 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1110}
1111
1112} // namespace Step3_InstantiateAffineMaps
1113
1115
1117class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1118 ConversionTracker &tracker_;
1119
1120public:
1121 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1122 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1123
1124 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1125 Value createResult = op.getResult();
1126 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1127 assert(createResultType && "CreateArrayOp must produce ArrayType");
1128 Type oldResultElemType = createResultType.getElementType();
1129
1130 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1131 // element type is different.
1132 Type newResultElemType = nullptr;
1133 for (Operation *user : createResult.getUsers()) {
1134 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1135 if (writeOp.getArrRef() != createResult) {
1136 continue;
1137 }
1138 Type writeRValueType = writeOp.getRvalue().getType();
1139 if (writeRValueType == oldResultElemType) {
1140 continue;
1141 }
1142 if (newResultElemType && newResultElemType != writeRValueType) {
1143 LLVM_DEBUG(
1144 llvm::dbgs()
1145 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1146 << newResultElemType << " vs " << writeRValueType << '\n'
1147 );
1148 return failure();
1149 }
1150 newResultElemType = writeRValueType;
1151 }
1152 }
1153 if (!newResultElemType) {
1154 // no replacement type found
1155 return failure();
1156 }
1157 if (!tracker_.isLegalConversion(
1158 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1159 )) {
1160 return failure();
1161 }
1162 ArrayType newType = createResultType.cloneWith(newResultElemType);
1163 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1164 LLVM_DEBUG(
1165 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1166 );
1167 return success();
1168 }
1169};
1170
1171namespace {
1172
1173LogicalResult updateArrayElemFromArrAccessOp(
1174 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1175 PatternRewriter &rewriter
1176) {
1177 ArrayType oldArrType = op.getArrRefType();
1178 if (oldArrType.getElementType() == scalarElemTy) {
1179 return failure(); // no change needed
1180 }
1181 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1182 if (oldArrType == newArrType ||
1183 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1184 return failure();
1185 }
1186 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1187 LLVM_DEBUG(
1188 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1189 );
1190 return success();
1191}
1192
1193} // namespace
1194
1195class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1196 ConversionTracker &tracker_;
1197
1198public:
1199 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1200 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1201
1202 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1203 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1204 }
1205};
1206
1207class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1208 ConversionTracker &tracker_;
1209
1210public:
1211 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1212 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1213
1214 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1215 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1216 }
1217};
1218
1220class UpdateFieldDefTypeFromWrite final : public OpRewritePattern<FieldDefOp> {
1221 ConversionTracker &tracker_;
1222
1223public:
1224 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1225 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1226
1227 LogicalResult matchAndRewrite(FieldDefOp op, PatternRewriter &rewriter) const override {
1228 // Find all uses of the field symbol name within its parent struct.
1229 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(op);
1230 assert(succeeded(parentRes) && "FieldDefOp parent is always StructDefOp"); // per ODS def
1231
1232 // If the symbol is used by a FieldWriteOp with a different result type then change
1233 // the type of the FieldDefOp to match the FieldWriteOp result type.
1234 Type newType = nullptr;
1235 if (auto fieldUsers = llzk::getSymbolUses(op, parentRes.value())) {
1236 std::optional<Location> newTypeLoc = std::nullopt;
1237 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1238 if (FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1239 Type writeToType = writeOp.getVal().getType();
1240 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] checking " << writeOp << '\n');
1241 if (!newType) {
1242 // If a new type has not yet been discovered, store the new type.
1243 newType = writeToType;
1244 newTypeLoc = writeOp.getLoc();
1245 } else if (writeToType != newType) {
1246 // Typically, there will only be one write for each field of a struct but do not rely on
1247 // that assumption. If multiple writes with a different types A and B are found where
1248 // A->B is a legal conversion (i.e., more concrete unification), then it is safe to use
1249 // type B with the assumption that the write with type A will be updated by another
1250 // pattern to also use type B.
1251 if (!tracker_.isLegalConversion(writeToType, newType, "UpdateFieldDefTypeFromWrite")) {
1252 if (tracker_.isLegalConversion(newType, writeToType, "UpdateFieldDefTypeFromWrite")) {
1253 // 'writeToType' is the more concrete type
1254 newType = writeToType;
1255 newTypeLoc = writeOp.getLoc();
1256 } else {
1257 // Give an error if the types are incompatible.
1258 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1259 diag.append(
1260 "Cannot update type of '", FieldDefOp::getOperationName(),
1261 "' because there are multiple '", FieldWriteOp::getOperationName(),
1262 "' with different value types"
1263 );
1264 if (newTypeLoc) {
1265 diag.attachNote(*newTypeLoc).append("type written here is ", newType);
1266 }
1267 diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
1268 });
1269 }
1270 }
1271 }
1272 }
1273 }
1274 }
1275 if (!newType || newType == op.getType()) {
1276 // nothing changed
1277 return failure();
1278 }
1279 if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateFieldDefTypeFromWrite")) {
1280 return failure();
1281 }
1282 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
1283 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] updated type of " << op << '\n');
1284 return success();
1285 }
1286};
1287
1288namespace {
1289
1290SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1291 SmallVector<std::unique_ptr<Region>> newRegions;
1292 for (Region &region : op->getRegions()) {
1293 auto newRegion = std::make_unique<Region>();
1294 newRegion->takeBody(region);
1295 newRegions.push_back(std::move(newRegion));
1296 }
1297 return newRegions;
1298}
1299
1300} // namespace
1301
1304class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1305 ConversionTracker &tracker_;
1306
1307public:
1308 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1309 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1310
1311 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1312 SmallVector<Type, 1> inferredResultTypes;
1313 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1314 LogicalResult result = retTypeFn.inferReturnTypes(
1315 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1316 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1317 );
1318 if (failed(result)) {
1319 return failure();
1320 }
1321 if (op->getResultTypes() == inferredResultTypes) {
1322 // nothing changed
1323 return failure();
1324 }
1325 if (!tracker_.areLegalConversions(
1326 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1327 )) {
1328 return failure();
1329 }
1330
1331 // Move nested region bodies and replace the original op with the updated types list.
1332 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1333 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1334 Operation *newOp = rewriter.create(
1335 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1336 op->getAttrs(), op->getSuccessors(), newRegions
1337 );
1338 rewriter.replaceOp(op, newOp);
1339 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1340 return success();
1341 }
1342};
1343
1345class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1346 ConversionTracker &tracker_;
1347
1348public:
1349 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1350 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1351
1352 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1353 Region &body = op.getFunctionBody();
1354 if (body.empty()) {
1355 return failure();
1356 }
1357 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1358 assert(retOp && "final op in body region must be return");
1359 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1360
1361 FunctionType oldFuncTy = op.getFunctionType();
1362 if (oldFuncTy.getResults() == tyFromReturnOp) {
1363 // nothing changed
1364 return failure();
1365 }
1366 if (!tracker_.areLegalConversions(
1367 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1368 )) {
1369 return failure();
1370 }
1371
1372 rewriter.modifyOpInPlace(op, [&]() {
1373 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1374 });
1375 LLVM_DEBUG(
1376 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
1377 << oldFuncTy << " to " << op.getFunctionType() << '\n'
1378 );
1379 return success();
1380 }
1381};
1382
1387class UpdateGlobalCallOpTypes final : public OpRewritePattern<CallOp> {
1388 ConversionTracker &tracker_;
1389
1390public:
1391 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1392 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1393
1394 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1395 SymbolTableCollection tables;
1396 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1397 if (failed(lookupRes)) {
1398 return failure();
1399 }
1400 FuncDefOp targetFunc = lookupRes->get();
1401 if (targetFunc.isInStruct()) {
1402 // this pattern only applies when the callee is NOT in a struct
1403 return failure();
1404 }
1405 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
1406 // nothing changed
1407 return failure();
1408 }
1409 if (!tracker_.areLegalConversions(
1410 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
1411 "UpdateGlobalCallOpTypes"
1412 )) {
1413 return failure();
1414 }
1415
1416 LLVM_DEBUG(llvm::dbgs() << "[UpdateGlobalCallOpTypes] replaced " << op);
1417 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
1418 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1419 return success();
1420 }
1421};
1422
1423namespace {
1424
1425LogicalResult updateFieldRefValFromFieldDef(
1426 FieldRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
1427) {
1428 SymbolTableCollection tables;
1429 auto def = op.getFieldDefOp(tables);
1430 if (failed(def)) {
1431 return failure();
1432 }
1433 Type oldResultType = op.getVal().getType();
1434 Type newResultType = def->get().getType();
1435 if (oldResultType == newResultType ||
1436 !tracker.isLegalConversion(oldResultType, newResultType, "updateFieldRefValFromFieldDef")) {
1437 return failure();
1438 }
1439 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
1440 LLVM_DEBUG(
1441 llvm::dbgs() << "[updateFieldRefValFromFieldDef] updated value type in " << op << '\n'
1442 );
1443 return success();
1444}
1445
1446} // namespace
1447
1449class UpdateFieldReadValFromDef final : public OpRewritePattern<FieldReadOp> {
1450 ConversionTracker &tracker_;
1451
1452public:
1453 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1454 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1455
1456 LogicalResult matchAndRewrite(FieldReadOp op, PatternRewriter &rewriter) const override {
1457 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1458 }
1459};
1460
1462class UpdateFieldWriteValFromDef final : public OpRewritePattern<FieldWriteOp> {
1463 ConversionTracker &tracker_;
1464
1465public:
1466 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1467 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1468
1469 LogicalResult matchAndRewrite(FieldWriteOp op, PatternRewriter &rewriter) const override {
1470 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1471 }
1472};
1473
1474LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1475 MLIRContext *ctx = modOp.getContext();
1476 RewritePatternSet patterns(ctx);
1477 patterns.add<
1478 // Benefit of this one must be higher than rules that would propagate the type in the opposite
1479 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
1480 // benefit = 6
1481 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
1482 // benefit = 3
1483 UpdateGlobalCallOpTypes, // CallOp, targeting non-struct functions
1484 UpdateFuncTypeFromReturn, // FuncDefOp
1485 UpdateNewArrayElemFromWrite, // CreateArrayOp
1486 UpdateArrayElemFromArrRead, // ReadArrayOp
1487 UpdateArrayElemFromArrWrite, // WriteArrayOp
1488 UpdateFieldDefTypeFromWrite, // FieldDefOp
1489 UpdateFieldReadValFromDef, // FieldReadOp
1490 UpdateFieldWriteValFromDef // FieldWriteOp
1491 >(ctx, tracker);
1492
1493 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1494}
1495} // namespace Step4_PropagateTypes
1496
1497namespace Step5_Cleanup {
1498
1499class CleanupBase {
1500public:
1501 SymbolTableCollection tables;
1502
1503 CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph)
1504 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1505
1506protected:
1507 ModuleOp rootMod;
1508 const SymbolDefTree &defTree;
1509 const SymbolUseGraph &useGraph;
1510};
1511
/// Cleanup strategy that preserves only the StructDefOps reachable from a set of roots.
struct FromKeepSet : public CleanupBase {
  using CleanupBase::CleanupBase;

  /// Erase every StructDefOp in `rootMod` whose SymbolUseGraph node is not reached from the roots.
  /// The roots are the given `keep` structs, every GlobalDefOp, and every FuncDefOp that is not
  /// within a struct ("free functions").
  ///
  /// For each root, all definitions below it in the SymbolDefTree are visited, and from each of
  /// those the symbols reachable in the SymbolUseGraph are collected. Any StructDefOp found along
  /// the way (excluding those loaded via an IncludeOp) becomes an additional root.
  ///
  /// Returns failure if a used symbol (other than a struct parameter) cannot be resolved.
  LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
    // Initialize roots from the given StructDefOp instances
    SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
    // Add GlobalDefOp and "free functions" to the set of roots
    rootMod.walk([&roots](Operation *op) {
      if (global::GlobalDefOp gdef = llvm::dyn_cast<global::GlobalDefOp>(op)) {
        roots.insert(gdef);
      } else if (function::FuncDefOp fdef = llvm::dyn_cast<function::FuncDefOp>(op)) {
        if (!fdef.isInStruct()) {
          roots.insert(fdef);
        }
      }
    });

    // Use a SymbolDefTree to find all Symbol defs reachable from one of the root nodes. Then
    // collect all Symbol uses reachable from those def nodes. These are the symbols that should
    // be preserved. All other symbol defs should be removed.
    // `symbolsToKeep` is the external visited set for `depth_first_ext()`, so it is shared by all
    // roots and nodes reached from an earlier root are not traversed again.
    llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
    // `roots` can grow inside this loop, so it is indexed and `roots.size()` is re-read each time.
    for (size_t i = 0; i < roots.size(); ++i) { // iterate for safe insertion
      SymbolOpInterface keepRoot = roots[i];
      LLVM_DEBUG({ llvm::dbgs() << "[EraseUnreachable] root: " << keepRoot << '\n'; });
      const SymbolDefTreeNode *keepRootNode = defTree.lookupNode(keepRoot);
      assert(keepRootNode && "every struct def must be in the def tree");
      for (const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
        LLVM_DEBUG({
          llvm::dbgs() << "[EraseUnreachable] can reach: " << reachableDefNode->getOp() << '\n';
        });
        if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
          // Use 'depth_first_ext()' to get all symbol uses reachable from the current Symbol def
          // node. There are no uses if the node is not in the graph. Within the loop that populates
          // 'depth_first_ext()', also check if the symbol is a StructDefOp and ensure it is in
          // 'roots' so the outer loop will ensure that all symbols reachable from it are preserved.
          if (const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
            for (const SymbolUseGraphNode *usedSymbolNode :
                 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
              LLVM_DEBUG({
                llvm::dbgs() << "[EraseUnreachable] uses symbol: "
                             << usedSymbolNode->getSymbolPath() << '\n';
              });
              // Ignore struct/template parameter symbols (before doing the lookup below because it
              // would fail anyway and then cause the "failed" case to be triggered unnecessarily).
              if (usedSymbolNode->isStructParam()) {
                continue;
              }
              // If `usedSymbolNode` references a StructDefOp, ensure it's considered in the roots.
              auto lookupRes = usedSymbolNode->lookupSymbol(tables);
              if (failed(lookupRes)) {
                LLVM_DEBUG(useGraph.dumpToDotFile());
                return failure();
              }
              // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
              if (lookupRes->viaInclude()) {
                continue;
              }
              if (StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
                bool insertRes = roots.insert(asStruct);
                LLVM_DEBUG({
                  if (insertRes) {
                    llvm::dbgs() << "[EraseUnreachable] found another root: " << asStruct << '\n';
                  }
                });
              }
            }
          }
        }
      }
    }

    // Erase every StructDefOp whose use-graph node was not collected above.
    rootMod.walk([this, &symbolsToKeep](StructDefOp op) {
      const SymbolUseGraphNode *n = this->useGraph.lookupNode(op);
      assert(n);
      if (!symbolsToKeep.contains(n)) {
        LLVM_DEBUG(llvm::dbgs() << "[EraseUnreachable] removing: " << op.getSymName() << '\n');
        op.erase();
      }

      // NOTE(review): `walk` defaults to post-order, where nested ops are visited before this
      // callback runs, so `skip()` does not prune the traversal here (it behaves like advance).
      return WalkResult::skip(); // StructDefOp cannot be nested
    });

    return success();
  }
};
1600
1601struct FromEraseSet : public CleanupBase {
1602
1604 FromEraseSet(
1605 ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph,
1606 DenseSet<SymbolRefAttr> &&tryToErasePaths
1607 )
1608 : CleanupBase(root, symDefTree, symUseGraph) {
1609 // Convert the set of paths targeted for erasure into a set of the StructDefOp
1610 for (SymbolRefAttr path : tryToErasePaths) {
1611 Operation *lookupFrom = rootMod.getOperation();
1612 auto res = lookupSymbolIn<StructDefOp>(tables, path, lookupFrom, lookupFrom);
1613 assert(succeeded(res) && "inputs must be valid StructDefOp references");
1614 if (!res->viaInclude()) { // do not remove if it's from another source file
1615 tryToErase.insert(res->get());
1616 }
1617 }
1618 }
1619
1620 LogicalResult eraseUnusedStructs() {
1621 // Collect the subset of 'tryToErase' that has no remaining uses.
1622 for (StructDefOp sd : tryToErase) {
1623 collectSafeToErase(sd);
1624 }
1625 // The `visitedPlusSafetyResult` will contain FuncDefOp w/in the StructDefOp so just a single
1626 // loop to `dyn_cast` and `erase()` will cause `use-after-free` errors w/in the `dyn_cast`.
1627 // Instead, reduce the map to only those that should be erased and erase in a separate loop.
1628 for (auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1629 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1630 visitedPlusSafetyResult.erase(it);
1631 }
1632 }
1633 for (auto &[sym, _] : visitedPlusSafetyResult) {
1634 LLVM_DEBUG(llvm::dbgs() << "[EraseIfUnused] removing: " << sym.getNameAttr() << '\n');
1635 sym.erase();
1636 }
1637 return success();
1638 }
1639
1640 const DenseSet<StructDefOp> &getTryToEraseSet() const { return tryToErase; }
1641
private:
  /// StructDefOps that are candidates for erasure (resolved from the constructor's paths,
  /// excluding any loaded via an IncludeOp).
  DenseSet<StructDefOp> tryToErase;
  /// Memoized results of `collectSafeToErase()`: true if the symbol is considered safe to erase.
  /// An entry is temporarily set to true while the symbol's own check is in progress.
  DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
  /// Cache of the symbol lookups performed by `cachedLookup()`.
  DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1651
/// Determine, and memoize in `visitedPlusSafetyResult`, whether `check` is safe to erase.
/// A StructDefOp not in `tryToErase` is never safe. Any other symbol is safe when its parent in
/// the SymbolDefTree is safe (or is the root) and every predecessor of its SymbolUseGraph node
/// (presumably the symbols that use it) is safe or comes from an IncludeOp.
///
/// NOTE(review): while `check` is in progress it is optimistically memoized as safe. Symbols whose
/// evaluation completes with `true` under that assumption (e.g., a FuncDefOp nested in `check`)
/// remain memoized as `true` even if `check` is afterwards reverted to `false` because of some
/// other user. A later query for a symbol that is only used by such a FuncDefOp could then be
/// answered from the stale `true` entry. Confirm whether this can happen for the inputs of this
/// pass; a fix would need to undo the tentative results on revert.
bool collectSafeToErase(SymbolOpInterface check) {
  assert(check); // pre-condition

  // If previously visited, return the safety result.
  auto visited = visitedPlusSafetyResult.find(check);
  if (visited != visitedPlusSafetyResult.end()) {
    return visited->second;
  }

  // If it's a StructDefOp that is not in `tryToErase` then it cannot be erased.
  if (StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
    if (!tryToErase.contains(sd)) {
      visitedPlusSafetyResult[check] = false;
      return false;
    }
  }

  // Otherwise, temporarily mark as safe b/c a node cannot keep itself live (and this prevents
  // the recursion from getting stuck in an infinite loop).
  visitedPlusSafetyResult[check] = true;

  // Check if it's safe according to both the def tree and use graph.
  // Note: every symbol must have a def node but module symbols may not have a use node.
  if (collectSafeToErase(defTree.lookupNode(check))) {
    auto useNode = useGraph.lookupNode(check);
    assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
    if (!useNode || collectSafeToErase(useNode)) {
      return true;
    }
  }

  // Otherwise, revert the safety decision and return it.
  visitedPlusSafetyResult[check] = false;
  return false;
}
1689
1691 bool collectSafeToErase(const SymbolDefTreeNode *check) {
1692 assert(check); // pre-condition
1693 if (const SymbolDefTreeNode *p = check->getParent()) {
1694 if (SymbolOpInterface checkOp = p->getOp()) { // safe if parent is root
1695 return collectSafeToErase(checkOp);
1696 }
1697 }
1698 return true;
1699 }
1700
1702 bool collectSafeToErase(const SymbolUseGraphNode *check) {
1703 assert(check); // pre-condition
1704 for (const SymbolUseGraphNode *p : check->predecessorIter()) {
1705 if (SymbolOpInterface checkOp = cachedLookup(p)) { // safe if via IncludeOp
1706 if (!collectSafeToErase(checkOp)) {
1707 return false;
1708 }
1709 }
1710 }
1711 return true;
1712 }
1713
1718 SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) {
1719 assert(node && "must provide a node"); // pre-condition
1720 // Check for cached result
1721 auto fromCache = lookupCache.find(node);
1722 if (fromCache != lookupCache.end()) {
1723 return fromCache->second;
1724 }
1725 // Otherwise, perform lookup and cache
1726 auto lookupRes = node->lookupSymbol(tables);
1727 assert(succeeded(lookupRes) && "graph contains node with invalid path");
1728 assert(lookupRes->get() != nullptr && "lookup must return an Operation");
1729 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1730 // NOTE: The SymbolUseGraph does contain nodes for struct parameters which cannot cast to
1731 // SymbolOpInterface. However, those will always be leaf nodes in the SymbolUseGraph and
1732 // therefore will not be traversed by this analysis so directly casting is fine.
1733 SymbolOpInterface actualRes =
1734 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1735 // Cache and return
1736 lookupCache[node] = actualRes;
1737 assert((!actualRes == lookupRes->viaInclude()) && "not found iff included"); // post-condition
1738 return actualRes;
1739 }
1740};
1741
1742} // namespace Step5_Cleanup
1743
1744class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<FlatteningPass> {
1745
1746 void runOnOperation() override {
1747 ModuleOp modOp = getOperation();
1748 if (failed(runOn(modOp))) {
1749 LLVM_DEBUG({
1750 // If the pass failed, dump the current IR.
1751 llvm::dbgs() << "=====================================================================\n";
1752 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << '\n';
1753 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1754 llvm::dbgs() << "=====================================================================\n";
1755 });
1756 signalPassFailure();
1757 }
1758 }
1759
1760 inline LogicalResult runOn(ModuleOp modOp) {
1761 // If the cleanup mode is set to remove anything not reachable from the "Main" struct, do an
1762 // initial pass to remove things that are not reachable (as an optimization) because creating
1763 // an instantiated version of a struct will not cause something to become reachable that was
1764 // not already reachable in parameterized form.
1765 if (cleanupMode == StructCleanupMode::MainAsRoot) {
1766 if (failed(eraseUnreachableFromMainStruct(modOp))) {
1767 return failure();
1768 }
1769 }
1770
1771 {
1772 // Preliminary step: remove empty parameter lists from structs
1773 OpPassManager nestedPM(ModuleOp::getOperationName());
1774 nestedPM.addPass(createEmptyParamListRemoval());
1775 if (failed(runPipeline(nestedPM, modOp))) {
1776 return failure();
1777 }
1778 }
1779
1780 ConversionTracker tracker;
1781 unsigned loopCount = 0;
1782 do {
1783 ++loopCount;
1784 if (loopCount > iterationLimit) {
1785 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit
1786 << " iterations!\n";
1787 return failure();
1788 }
1789 tracker.resetModifiedFlag();
1790
1791 // Find calls to "compute()" that return a parameterized struct and replace it to call a
1792 // flattened version of the struct that has parameters replaced with the constant values.
1793 // Create the necessary instantiated/flattened struct in the same location as the original.
1794 if (failed(Step1_InstantiateStructs::run(modOp, tracker))) {
1795 llvm::errs() << DEBUG_TYPE << " failed while replacing concrete-parameter struct types\n";
1796 return failure();
1797 }
1798
1799 // Unroll loops with known iterations.
1800 if (failed(Step2_Unroll::run(modOp, tracker))) {
1801 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
1802 return failure();
1803 }
1804
1805 // Instantiate affine_map parameters of StructType and ArrayType.
1806 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
1807 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
1808 return failure();
1809 }
1810
1811 // Propagate updated types using the semantics of various ops.
1812 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
1813 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
1814 return failure();
1815 }
1816
1817 LLVM_DEBUG(if (tracker.isModified()) {
1818 llvm::dbgs() << "=====================================================================\n";
1819 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << '\n';
1820 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1821 llvm::dbgs() << "=====================================================================\n";
1822 });
1823 } while (tracker.isModified());
1824
1825 // Perform cleanup according to the 'cleanupMode' option.
1826 switch (cleanupMode) {
1827 case StructCleanupMode::MainAsRoot:
1828 return eraseUnreachableFromMainStruct(modOp, false);
1829 case StructCleanupMode::ConcreteAsRoot:
1830 return eraseUnreachableFromConcreteStructs(modOp);
1831 case StructCleanupMode::Preimage:
1832 return erasePreimageOfInstantiations(modOp, tracker);
1833 case StructCleanupMode::Disabled:
1834 return success();
1835 }
1836 llvm_unreachable("switch cases cover all options");
1837 }
1838
1839 // Erase parameterized structs that were replaced with concrete instantiations.
1840 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod, const ConversionTracker &tracker) {
1841 // TODO: The names from getInstantiatedStructNames() are NOT guaranteed to be paths from the
1842 // "top root" and they also do not indicate a root module so there could be ambiguity. This is a
1843 // broader problem in the FlatteningPass itself so let's just assume, for now, that these are
1844 // paths from the "top root". See [LLZK-286].
1845 Step5_Cleanup::FromEraseSet cleaner(
1846 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
1847 tracker.getInstantiatedStructNames()
1848 );
1849 LogicalResult res = cleaner.eraseUnusedStructs();
1850 if (succeeded(res)) {
1851 // Warn about any structs that were instantiated but still have uses elsewhere.
1852 const SymbolUseGraph *useGraph = nullptr;
1853 rootMod->walk([this, &cleaner, &useGraph](StructDefOp op) {
1854 if (cleaner.getTryToEraseSet().contains(op)) {
1855 // If needed, rebuild use graph to reflect deletions.
1856 if (!useGraph) {
1857 useGraph = &getAnalysis<SymbolUseGraph>();
1858 }
1859 // If the op has any users, report the warning.
1860 if (useGraph->lookupNode(op)->hasPredecessor()) {
1861 op.emitWarning("Parameterized struct still has uses!").report();
1862 }
1863 }
1864 return WalkResult::skip(); // StructDefOp cannot be nested
1865 });
1866 }
1867 return res;
1868 }
1869
1870 LogicalResult eraseUnreachableFromConcreteStructs(ModuleOp rootMod) {
1871 SmallVector<StructDefOp> roots;
1872 rootMod.walk([&roots](StructDefOp op) {
1873 // Note: no need to check if the ConstParamsAttr is empty since `EmptyParamRemovalPass`
1874 // ran earlier.
1875 if (!op.hasConstParamsAttr()) {
1876 roots.push_back(op);
1877 }
1878 return WalkResult::skip(); // StructDefOp cannot be nested
1879 });
1880
1881 Step5_Cleanup::FromKeepSet cleaner(
1882 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1883 );
1884 return cleaner.eraseUnreachableFrom(roots);
1885 }
1886
1887 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod, bool emitWarning = true) {
1888 Step5_Cleanup::FromKeepSet cleaner(
1889 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1890 );
1891 StructDefOp main =
1892 cleaner.tables.getSymbolTable(rootMod).lookup<StructDefOp>(COMPONENT_NAME_MAIN);
1893 if (emitWarning && !main) {
1894 // Emit warning if there is no "Main" because all structs may be removed (only structs that
1895 // are reachable from a global def or free function will be preserved since those constructs
1896 // are not candidate for removal in this pass).
1897 rootMod.emitWarning()
1898 .append(
1899 "using option '", cleanupMode.getArgStr(), '=',
1900 stringifyStructCleanupMode(StructCleanupMode::MainAsRoot), "' with no \"",
1901 COMPONENT_NAME_MAIN, "\" struct may remove all structs!"
1902 )
1903 .report();
1904 }
1905 return cleaner.eraseUnreachableFrom(
1906 main ? ArrayRef<StructDefOp> {main} : ArrayRef<StructDefOp> {}
1907 );
1908 }
1909};
1910
1911} // namespace
1912
1914 return std::make_unique<FlatteningPass>();
1915};
#define DEBUG_TYPE
#define DEBUG_TYPE
Common private implementation for poly dialect passes.
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
const SymbolDefTreeNode * getParent() const
Returns the parent node in the tree. The root node will return nullptr.
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool hasPredecessor() const
Return true if this node has any predecessors.
llvm::iterator_range< iterator > predecessorIter() const
Range over predecessor nodes.
Builds a graph structure representing the relationships between symbols and their uses.
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:921
::mlir::TypedValue<::mlir::Type > getRvalue()
Definition Ops.h.inc:1073
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:502
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:332
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:905
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.h.inc:712
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:593
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:900
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1152
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1600
bool hasConstParamsAttr()
Return false iff getConstParamsAttr() returns nullptr
Definition Ops.h.inc:1279
void setConstParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:1204
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::ArrayAttr getParams() const
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:47
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:761
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:784
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:755
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:272
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Returns the argument types of this function.
Definition Ops.h.inc:737
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:357
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:947
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:772
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:769
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:971
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:884
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:443
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:149
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:240
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:269
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:126
std::unique_ptr< mlir::Pass > createFlatteningPass()
std::unique_ptr< mlir::Pass > createEmptyParamListRemoval()
::llvm::StringRef stringifyStructCleanupMode(StructCleanupMode val)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:225
bool isConcreteType(Type type, bool allowStructParams)
constexpr char COMPONENT_NAME_MAIN[]
Symbol name for the main entry point struct/component (if any).
Definition Constants.h:23
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:255
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:185
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:259
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
bool isDynamic(IntegerAttr intAttr)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
int64_t fromAPInt(const llvm::APInt &i)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)