253static inline bool tableOffsetIsntSymbol(
FieldReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
260 ConversionTracker &tracker_;
262 SymbolTableCollection symTables;
264 class MappedTypeConverter :
public TypeConverter {
267 const DenseMap<Attribute, Attribute> ¶mNameToValue;
269 inline Attribute convertIfPossible(Attribute a)
const {
270 auto res = this->paramNameToValue.find(a);
271 return (res != this->paramNameToValue.end()) ? res->second : a;
278 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
280 : TypeConverter(), origTy(originalType), newTy(newType),
281 paramNameToValue(paramNameToInstantiatedValue) {
283 addConversion([](Type inputTy) {
return inputTy; });
286 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
289 if (inputTy == this->origTy) {
293 if (ArrayAttr inputTyParams = inputTy.getParams()) {
294 SmallVector<Attribute> updated;
295 for (Attribute a : inputTyParams) {
296 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
297 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
299 updated.push_back(convertIfPossible(a));
303 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
310 addConversion([
this](
ArrayType inputTy) {
312 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
313 if (!dimSizes.empty()) {
314 SmallVector<Attribute> updated;
315 for (Attribute a : dimSizes) {
316 updated.push_back(convertIfPossible(a));
318 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
324 addConversion([
this](
TypeVarType inputTy) -> Type {
326 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
327 Type convertedType = tyAttr.getValue();
332 return convertedType;
340 template <
typename Impl,
typename Op,
typename... HandledAttrs>
341 class SymbolUserHelper :
public OpConversionPattern<Op> {
343 const DenseMap<Attribute, Attribute> ¶mNameToValue;
346 TypeConverter &converter, MLIRContext *ctx,
unsigned Benefit,
347 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
349 : OpConversionPattern<Op>(converter, ctx, Benefit),
350 paramNameToValue(paramNameToInstantiatedValue) {}
353 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
355 virtual Attribute getNameAttr(Op)
const = 0;
357 virtual LogicalResult handleDefaultRewrite(
358 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
360 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
364 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
365 auto res = this->paramNameToValue.find(getNameAttr(op));
366 if (res == this->paramNameToValue.end()) {
367 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] no instantiation for " << op <<
'\n');
370 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
371 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
373 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
374 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
378 return TS.Default([&](Attribute a) {
379 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
385 class ClonedStructConstReadOpPattern
386 :
public SymbolUserHelper<
387 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
388 SmallVector<Diagnostic> &diagnostics;
391 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
394 ClonedStructConstReadOpPattern(
395 TypeConverter &converter, MLIRContext *ctx,
396 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue,
397 SmallVector<Diagnostic> &instantiationDiagnostics
401 : super(converter, ctx, 2, paramNameToInstantiatedValue),
402 diagnostics(instantiationDiagnostics) {}
406 LogicalResult handleRewrite(
407 Attribute sym,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
409 APInt attrValue = a.getValue();
410 Type origResTy = op.getType();
411 if (llvm::isa<FeltType>(origResTy)) {
413 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
418 if (llvm::isa<IndexType>(origResTy)) {
423 if (origResTy.isSignlessInteger(1)) {
425 if (attrValue.isZero()) {
429 if (!attrValue.isOne()) {
430 Location opLoc = op.getLoc();
431 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
433 if (getContext()->shouldPrintOpOnDiagnostic()) {
434 diag.attachNote(opLoc) <<
"see current operation: " << *op;
436 diag.attachNote(UnknownLoc::get(getContext()))
438 << sym <<
"\" for this call";
439 diagnostics.push_back(std::move(diag));
444 return op->emitOpError().append(
"unexpected result type ", origResTy);
447 LogicalResult handleRewrite(
448 Attribute,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
455 class ClonedStructFieldReadOpPattern
456 :
public SymbolUserHelper<
457 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
459 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
462 ClonedStructFieldReadOpPattern(
463 TypeConverter &converter, MLIRContext *ctx,
464 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
468 : super(converter, ctx, 2, paramNameToInstantiatedValue) {}
470 Attribute getNameAttr(
FieldReadOp op)
const override {
474 template <
typename Attr>
475 LogicalResult handleRewrite(
476 Attribute,
FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
478 rewriter.modifyOpInPlace(op, [&]() {
485 LogicalResult matchAndRewrite(
486 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
488 if (tableOffsetIsntSymbol(op)) {
492 return super::matchAndRewrite(op, adaptor, rewriter);
496 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
498 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.
getDefinition(symTables, rootMod);
500 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
506 MLIRContext *ctx = origStruct.getContext();
509 DenseMap<Attribute, Attribute> paramNameToConcrete;
513 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
516 ArrayAttr reducedParamNameList =
nullptr;
518 ArrayAttr reducedCallerParams =
nullptr;
520 ArrayAttr paramNames = typeAtDef.
getParams();
524 assert(paramNames.size() == typeAtCallerParams.size());
526 SmallVector<Attribute> remainingNames;
527 SmallVector<Attribute> nonConcreteParams;
528 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
529 Attribute next = typeAtCallerParams[i];
530 if (isConcreteAttr<false>(next)) {
531 paramNameToConcrete[paramNames[i]] = next;
532 attrsForInstantiatedNameSuffix.push_back(next);
534 remainingNames.push_back(paramNames[i]);
535 nonConcreteParams.push_back(next);
536 attrsForInstantiatedNameSuffix.push_back(
nullptr);
540 assert(remainingNames.size() == nonConcreteParams.size());
541 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
542 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
544 if (paramNameToConcrete.empty()) {
545 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
548 if (!remainingNames.empty()) {
549 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
550 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
560 typeAtCaller.
getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
566 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>();
567 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
572 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
573 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
574 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
575 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
581 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
582 ConversionTarget target =
586 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
590 patterns.add<ClonedStructConstReadOpPattern>(
591 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
593 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
594 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
595 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
598 return newRemoteType;
602 StructCloner(ConversionTracker &tracker, ModuleOp root)
603 : tracker_(tracker), rootMod(root), symTables() {}
605 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
606 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
607 if (ArrayAttr params = orig.
getParams()) {
608 return genClone(orig, params.getValue());
610 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
615class ParameterizedStructUseTypeConverter :
public TypeConverter {
616 ConversionTracker &tracker_;
620 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
621 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
623 addConversion([](Type inputTy) {
return inputTy; });
627 if (
auto opt = tracker_.getInstantiation(inputTy)) {
633 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
634 if (failed(cloneRes)) {
639 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
640 <<
" as " << newTy <<
'\n'
642 tracker_.recordInstantiation(inputTy, newTy);
646 addConversion([
this](
ArrayType inputTy) {
647 return inputTy.cloneWith(convertType(inputTy.getElementType()));
652class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
653 ConversionTracker &tracker_;
656 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
659 : OpConversionPattern<CallOp>(converter, ctx, 2), tracker_(tracker) {}
661 LogicalResult matchAndRewrite(
662 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
664 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
667 SmallVector<Type> newResultTypes;
668 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
669 return op->emitError(
"Could not convert Op result types.");
672 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
682 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
683 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
684 tracker_.reportDelayedDiagnostics(newStTy, op);
688 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
689 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
693 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
695 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
698 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
704class FieldDefOpPattern :
public OpConversionPattern<FieldDefOp> {
706 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
709 : OpConversionPattern<FieldDefOp>(converter, ctx, 2) {}
711 LogicalResult matchAndRewrite(
712 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
714 LLVM_DEBUG(llvm::dbgs() <<
"[FieldDefOpPattern] FieldDefOp: " << op <<
'\n');
716 Type oldFieldType = op.
getType();
717 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
718 if (oldFieldType == newFieldType) {
722 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.
setType(newFieldType); });
727LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
728 MLIRContext *ctx = modOp.getContext();
729 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
732 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
733 return applyPartialConversion(modOp, target, std::move(patterns));
788std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
789 SmallVector<int64_t> res;
790 for (OpFoldResult ofr : ofrs) {
791 std::optional<int64_t> cv = getConstantIntValue(ofr);
792 if (!cv.has_value()) {
795 res.push_back(cv.value());
800struct AffineMapFolder {
802 OperandRangeRange mapOpGroups;
803 DenseI32ArrayAttr dimsPerGroup;
804 ArrayRef<Attribute> paramsOfStructTy;
808 SmallVector<SmallVector<Value>> mapOpGroups;
809 SmallVector<int32_t> dimsPerGroup;
810 SmallVector<Attribute> paramsOfStructTy;
813 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
814 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
815 return ValueRange(grp);
820 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
821 if (in.mapOpGroups.empty()) {
826 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
827 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
830 for (Attribute sizeAttr : in.paramsOfStructTy) {
831 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
832 ValueRange currMapOps = in.mapOpGroups[idx++];
837 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
839 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
842 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
843 SmallVector<Attribute> result;
844 bool hasPoison =
false;
845 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
846 return rewriter.getIndexAttr(v);
848 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
850 LLVM_DEBUG(op->emitRemark()
852 "Cannot fold affine_map for ", aspect,
" ",
853 out.paramsOfStructTy.size(),
854 " due to divide by 0 or modulus with negative divisor"
859 if (failed(foldResult)) {
860 LLVM_DEBUG(op->emitRemark()
862 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
868 if (result.size() != 1) {
869 LLVM_DEBUG(op->emitRemark()
871 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
872 " produced ", result.size(),
" results but expected 1"
877 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
878 out.paramsOfStructTy.push_back(result[0]);
882 out.mapOpGroups.emplace_back(currMapOps);
883 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
886 out.paramsOfStructTy.push_back(sizeAttr);
888 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
890 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
891 "produced wrong number of dimensions"
899class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
900 ConversionTracker &tracker_;
903 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
904 : OpRewritePattern(ctx), tracker_(tracker) {}
906 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
909 AffineMapFolder::Output out;
910 AffineMapFolder::Input in = {
915 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
920 if (newResultType == oldResultType) {
925 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
927 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
928 << newResultType <<
" in \"" << op <<
"\"\n"
931 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
938class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
939 ConversionTracker &tracker_;
942 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
943 : OpRewritePattern(ctx), tracker_(tracker) {}
945 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
950 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
952 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
959 AffineMapFolder::Output out;
960 AffineMapFolder::Input in = {
965 if (!in.mapOpGroups.empty()) {
967 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
971 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
977 if (callArgTypes.empty()) {
981 SymbolTableCollection tables;
983 if (failed(lookupRes)) {
986 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
990 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
991 "result type params: "
997 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
998 if (newRetTy == oldRetTy) {
1006 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
1007 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1009 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1010 ", but found ", oldRetTy
1014 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1016 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1017 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1019 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1026 inline LogicalResult instantiateViaTargetType(
1027 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1028 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1033 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1035 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1041 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1043 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1045 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1047 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1053 assert(unifies &&
"should have been checked by verifiers");
1056 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1065 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1066 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1067 [&unifications](std::tuple<Attribute, Attribute> p) {
1068 Attribute fromCall = std::get<1>(p);
1071 if (!isConcreteAttr<>(fromCall)) {
1072 Attribute fromTgt = std::get<0>(p);
1074 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1075 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1077 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1078 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1079 if (it != unifications.end()) {
1080 Attribute unifiedAttr = it->second;
1082 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1084 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1093 out.paramsOfStructTy = newReturnStructParams;
1094 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1095 assert(out.mapOpGroups.empty() &&
"post-condition");
1096 assert(out.dimsPerGroup.empty() &&
"post-condition");
1101LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1102 MLIRContext *ctx = modOp.getContext();
1103 RewritePatternSet patterns(ctx);
1105 InstantiateAtCreateArrayOp,
1106 InstantiateAtCallOpCompute
1109 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1117class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1118 ConversionTracker &tracker_;
1121 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1122 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1124 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1126 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1127 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1132 Type newResultElemType =
nullptr;
1133 for (Operation *user : createResult.getUsers()) {
1134 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1135 if (writeOp.getArrRef() != createResult) {
1138 Type writeRValueType = writeOp.getRvalue().getType();
1139 if (writeRValueType == oldResultElemType) {
1142 if (newResultElemType && newResultElemType != writeRValueType) {
1145 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1146 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1150 newResultElemType = writeRValueType;
1153 if (!newResultElemType) {
1157 if (!tracker_.isLegalConversion(
1158 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1163 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1165 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
1173LogicalResult updateArrayElemFromArrAccessOp(
1175 PatternRewriter &rewriter
1182 if (oldArrType == newArrType ||
1183 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1186 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1188 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
1195class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1196 ConversionTracker &tracker_;
1199 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1200 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1202 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1203 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
1207class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1208 ConversionTracker &tracker_;
1211 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1212 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1214 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1215 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
1220class UpdateFieldDefTypeFromWrite final :
public OpRewritePattern<FieldDefOp> {
1221 ConversionTracker &tracker_;
1224 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1225 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1227 LogicalResult matchAndRewrite(
FieldDefOp op, PatternRewriter &rewriter)
const override {
1230 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
1234 Type newType =
nullptr;
1236 std::optional<Location> newTypeLoc = std::nullopt;
1237 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1238 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1239 Type writeToType = writeOp.getVal().getType();
1240 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] checking " << writeOp <<
'\n');
1243 newType = writeToType;
1244 newTypeLoc = writeOp.getLoc();
1245 }
else if (writeToType != newType) {
1251 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateFieldDefTypeFromWrite")) {
1252 if (tracker_.isLegalConversion(newType, writeToType,
"UpdateFieldDefTypeFromWrite")) {
1254 newType = writeToType;
1255 newTypeLoc = writeOp.getLoc();
1258 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1262 "' with different value types"
1265 diag.attachNote(*newTypeLoc).append(
"type written here is ", newType);
1267 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1275 if (!newType || newType == op.
getType()) {
1279 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateFieldDefTypeFromWrite")) {
1282 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1283 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] updated type of " << op <<
'\n');
1290SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1291 SmallVector<std::unique_ptr<Region>> newRegions;
1292 for (Region ®ion : op->getRegions()) {
1293 auto newRegion = std::make_unique<Region>();
1294 newRegion->takeBody(region);
1295 newRegions.push_back(std::move(newRegion));
1304class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1305 ConversionTracker &tracker_;
1308 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1309 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1311 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1312 SmallVector<Type, 1> inferredResultTypes;
1313 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1314 LogicalResult result = retTypeFn.inferReturnTypes(
1315 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1316 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1318 if (failed(result)) {
1321 if (op->getResultTypes() == inferredResultTypes) {
1325 if (!tracker_.areLegalConversions(
1326 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1332 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
1333 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1334 Operation *newOp = rewriter.create(
1335 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1336 op->getAttrs(), op->getSuccessors(), newRegions
1338 rewriter.replaceOp(op, newOp);
1339 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
1345class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1346 ConversionTracker &tracker_;
1349 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1350 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1352 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1353 Region &body = op.getFunctionBody();
1357 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1358 assert(retOp &&
"final op in body region must be return");
1359 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1362 if (oldFuncTy.getResults() == tyFromReturnOp) {
1366 if (!tracker_.areLegalConversions(
1367 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1372 rewriter.modifyOpInPlace(op, [&]() {
1373 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1376 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
1387class UpdateGlobalCallOpTypes final :
public OpRewritePattern<CallOp> {
1388 ConversionTracker &tracker_;
1391 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1392 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1394 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1395 SymbolTableCollection tables;
1397 if (failed(lookupRes)) {
1400 FuncDefOp targetFunc = lookupRes->get();
1405 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
1409 if (!tracker_.areLegalConversions(
1411 "UpdateGlobalCallOpTypes"
1416 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateGlobalCallOpTypes] replaced " << op);
1418 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1425LogicalResult updateFieldRefValFromFieldDef(
1428 SymbolTableCollection tables;
1433 Type oldResultType = op.
getVal().getType();
1434 Type newResultType = def->get().getType();
1435 if (oldResultType == newResultType ||
1436 !tracker.isLegalConversion(oldResultType, newResultType,
"updateFieldRefValFromFieldDef")) {
1439 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
1441 llvm::dbgs() <<
"[updateFieldRefValFromFieldDef] updated value type in " << op <<
'\n'
1449class UpdateFieldReadValFromDef final :
public OpRewritePattern<FieldReadOp> {
1450 ConversionTracker &tracker_;
1453 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1454 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1456 LogicalResult matchAndRewrite(
FieldReadOp op, PatternRewriter &rewriter)
const override {
1457 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1462class UpdateFieldWriteValFromDef final :
public OpRewritePattern<FieldWriteOp> {
1463 ConversionTracker &tracker_;
1466 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1467 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1469 LogicalResult matchAndRewrite(
FieldWriteOp op, PatternRewriter &rewriter)
const override {
1470 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1474LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1475 MLIRContext *ctx = modOp.getContext();
1476 RewritePatternSet patterns(ctx);
1481 UpdateInferredResultTypes,
1483 UpdateGlobalCallOpTypes,
1484 UpdateFuncTypeFromReturn,
1485 UpdateNewArrayElemFromWrite,
1486 UpdateArrayElemFromArrRead,
1487 UpdateArrayElemFromArrWrite,
1488 UpdateFieldDefTypeFromWrite,
1489 UpdateFieldReadValFromDef,
1490 UpdateFieldWriteValFromDef
1493 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1501 SymbolTableCollection tables;
1504 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1512struct FromKeepSet :
public CleanupBase {
1513 using CleanupBase::CleanupBase;
1518 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1520 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1522 rootMod.walk([&roots](Operation *op) {
1526 if (!fdef.isInStruct()) {
1535 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
1536 for (
size_t i = 0; i < roots.size(); ++i) {
1537 SymbolOpInterface keepRoot = roots[i];
1538 LLVM_DEBUG({ llvm::dbgs() <<
"[EraseUnreachable] root: " << keepRoot <<
'\n'; });
1540 assert(keepRootNode &&
"every struct def must be in the def tree");
1541 for (
const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1543 llvm::dbgs() <<
"[EraseUnreachable] can reach: " << reachableDefNode->getOp() <<
'\n';
1545 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1550 if (
const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1552 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1554 llvm::dbgs() <<
"[EraseUnreachable] uses symbol: "
1555 << usedSymbolNode->getSymbolPath() <<
'\n';
1559 if (usedSymbolNode->isStructParam()) {
1563 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1564 if (failed(lookupRes)) {
1565 LLVM_DEBUG(useGraph.dumpToDotFile());
1569 if (lookupRes->viaInclude()) {
1572 if (
StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1573 bool insertRes = roots.insert(asStruct);
1576 llvm::dbgs() <<
"[EraseUnreachable] found another root: " << asStruct <<
'\n';
1586 rootMod.walk([
this, &symbolsToKeep](
StructDefOp op) {
1589 if (!symbolsToKeep.contains(n)) {
1590 LLVM_DEBUG(llvm::dbgs() <<
"[EraseUnreachable] removing: " << op.getSymName() <<
'\n');
1594 return WalkResult::skip();
1601struct FromEraseSet :
public CleanupBase {
1606 DenseSet<SymbolRefAttr> &&tryToErasePaths
1608 : CleanupBase(root, symDefTree, symUseGraph) {
1610 for (SymbolRefAttr path : tryToErasePaths) {
1611 Operation *lookupFrom = rootMod.getOperation();
1613 assert(succeeded(res) &&
"inputs must be valid StructDefOp references");
1614 if (!res->viaInclude()) {
1615 tryToErase.insert(res->get());
1620 LogicalResult eraseUnusedStructs() {
1623 collectSafeToErase(sd);
1628 for (
auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1629 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1630 visitedPlusSafetyResult.erase(it);
1633 for (
auto &[sym, _] : visitedPlusSafetyResult) {
1634 LLVM_DEBUG(llvm::dbgs() <<
"[EraseIfUnused] removing: " << sym.getNameAttr() <<
'\n');
1640 const DenseSet<StructDefOp> &getTryToEraseSet()
const {
return tryToErase; }
1644 DenseSet<StructDefOp> tryToErase;
1648 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
1650 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1654 bool collectSafeToErase(SymbolOpInterface check) {
1658 auto visited = visitedPlusSafetyResult.find(check);
1659 if (visited != visitedPlusSafetyResult.end()) {
1660 return visited->second;
1664 if (
StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
1665 if (!tryToErase.contains(sd)) {
1666 visitedPlusSafetyResult[check] =
false;
1673 visitedPlusSafetyResult[check] =
true;
1677 if (collectSafeToErase(defTree.lookupNode(check))) {
1678 auto useNode = useGraph.lookupNode(check);
1679 assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
1680 if (!useNode || collectSafeToErase(useNode)) {
1686 visitedPlusSafetyResult[check] =
false;
1694 if (SymbolOpInterface checkOp = p->getOp()) {
1695 return collectSafeToErase(checkOp);
1705 if (SymbolOpInterface checkOp = cachedLookup(p)) {
1706 if (!collectSafeToErase(checkOp)) {
1719 assert(node &&
"must provide a node");
1721 auto fromCache = lookupCache.find(node);
1722 if (fromCache != lookupCache.end()) {
1723 return fromCache->second;
1727 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
1728 assert(lookupRes->get() !=
nullptr &&
"lookup must return an Operation");
1733 SymbolOpInterface actualRes =
1734 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1736 lookupCache[node] = actualRes;
1737 assert((!actualRes == lookupRes->viaInclude()) &&
"not found iff included");