253static inline bool tableOffsetIsntSymbol(
FieldReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
260 ConversionTracker &tracker_;
262 SymbolTableCollection symTables;
264 class MappedTypeConverter :
public TypeConverter {
267 const DenseMap<Attribute, Attribute> ¶mNameToValue;
269 inline Attribute convertIfPossible(Attribute a)
const {
270 auto res = this->paramNameToValue.find(a);
271 return (res != this->paramNameToValue.end()) ? res->second : a;
278 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
280 : TypeConverter(), origTy(originalType), newTy(newType),
281 paramNameToValue(paramNameToInstantiatedValue) {
283 addConversion([](Type inputTy) {
return inputTy; });
286 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
289 if (inputTy == this->origTy) {
293 if (ArrayAttr inputTyParams = inputTy.getParams()) {
294 SmallVector<Attribute> updated;
295 for (Attribute a : inputTyParams) {
296 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
297 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
299 updated.push_back(convertIfPossible(a));
303 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
310 addConversion([
this](
ArrayType inputTy) {
312 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
313 if (!dimSizes.empty()) {
314 SmallVector<Attribute> updated;
315 for (Attribute a : dimSizes) {
316 updated.push_back(convertIfPossible(a));
318 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
324 addConversion([
this](
TypeVarType inputTy) -> Type {
326 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
327 Type convertedType = tyAttr.getValue();
332 return convertedType;
340 template <
typename Impl,
typename Op,
typename... HandledAttrs>
341 class SymbolUserHelper :
public OpConversionPattern<Op> {
343 const DenseMap<Attribute, Attribute> ¶mNameToValue;
346 TypeConverter &converter, MLIRContext *ctx,
unsigned Benefit,
347 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
349 : OpConversionPattern<Op>(converter, ctx, Benefit),
350 paramNameToValue(paramNameToInstantiatedValue) {}
353 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
355 virtual Attribute getNameAttr(Op)
const = 0;
357 virtual LogicalResult handleDefaultRewrite(
358 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
360 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
364 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
365 auto res = this->paramNameToValue.find(getNameAttr(op));
366 if (res == this->paramNameToValue.end()) {
367 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] no instantiation for " << op <<
'\n');
370 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
371 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
373 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
374 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
378 return TS.Default([&](Attribute a) {
379 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
385 class ClonedStructConstReadOpPattern
386 :
public SymbolUserHelper<
387 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
388 SmallVector<Diagnostic> &diagnostics;
391 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
394 ClonedStructConstReadOpPattern(
395 TypeConverter &converter, MLIRContext *ctx,
396 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue,
397 SmallVector<Diagnostic> &instantiationDiagnostics
401 : super(converter, ctx, 2, paramNameToInstantiatedValue),
402 diagnostics(instantiationDiagnostics) {}
406 LogicalResult handleRewrite(
407 Attribute sym,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
409 APInt attrValue = a.getValue();
410 Type origResTy = op.getType();
411 if (llvm::isa<FeltType>(origResTy)) {
413 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
418 if (llvm::isa<IndexType>(origResTy)) {
423 if (origResTy.isSignlessInteger(1)) {
425 if (attrValue.isZero()) {
429 if (!attrValue.isOne()) {
430 Location opLoc = op.getLoc();
431 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
433 if (getContext()->shouldPrintOpOnDiagnostic()) {
434 diag.attachNote(opLoc) <<
"see current operation: " << *op;
436 diag.attachNote(UnknownLoc::get(getContext()))
438 << sym <<
"\" for this call";
439 diagnostics.push_back(std::move(diag));
444 return op->emitOpError().append(
"unexpected result type ", origResTy);
447 LogicalResult handleRewrite(
448 Attribute,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
455 class ClonedStructFieldReadOpPattern
456 :
public SymbolUserHelper<
457 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
459 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
462 ClonedStructFieldReadOpPattern(
463 TypeConverter &converter, MLIRContext *ctx,
464 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
468 : super(converter, ctx, 2, paramNameToInstantiatedValue) {}
470 Attribute getNameAttr(
FieldReadOp op)
const override {
474 template <
typename Attr>
475 LogicalResult handleRewrite(
476 Attribute,
FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
478 rewriter.modifyOpInPlace(op, [&]() {
485 LogicalResult matchAndRewrite(
486 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
488 if (tableOffsetIsntSymbol(op)) {
492 return super::matchAndRewrite(op, adaptor, rewriter);
496 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
498 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.
getDefinition(symTables, rootMod);
500 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
506 MLIRContext *ctx = origStruct.getContext();
509 DenseMap<Attribute, Attribute> paramNameToConcrete;
513 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
516 ArrayAttr reducedParamNameList =
nullptr;
518 ArrayAttr reducedCallerParams =
nullptr;
520 ArrayAttr paramNames = typeAtDef.
getParams();
524 assert(paramNames.size() == typeAtCallerParams.size());
526 SmallVector<Attribute> remainingNames;
527 SmallVector<Attribute> nonConcreteParams;
528 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
529 Attribute next = typeAtCallerParams[i];
530 if (isConcreteAttr<false>(next)) {
531 paramNameToConcrete[paramNames[i]] = next;
532 attrsForInstantiatedNameSuffix.push_back(next);
534 remainingNames.push_back(paramNames[i]);
535 nonConcreteParams.push_back(next);
536 attrsForInstantiatedNameSuffix.push_back(
nullptr);
540 assert(remainingNames.size() == nonConcreteParams.size());
541 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
542 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
544 if (paramNameToConcrete.empty()) {
545 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
548 if (!remainingNames.empty()) {
549 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
550 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
560 typeAtCaller.
getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
566 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>();
567 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
572 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
573 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
574 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
575 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
581 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
582 ConversionTarget target =
586 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
590 patterns.add<ClonedStructConstReadOpPattern>(
591 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
593 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
594 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
595 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
598 return newRemoteType;
602 StructCloner(ConversionTracker &tracker, ModuleOp root)
603 : tracker_(tracker), rootMod(root), symTables() {}
605 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
606 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
607 if (ArrayAttr params = orig.
getParams()) {
608 return genClone(orig, params.getValue());
610 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
615class ParameterizedStructUseTypeConverter :
public TypeConverter {
616 ConversionTracker &tracker_;
620 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
621 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
623 addConversion([](Type inputTy) {
return inputTy; });
627 if (
auto opt = tracker_.getInstantiation(inputTy)) {
633 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
634 if (failed(cloneRes)) {
639 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
640 <<
" as " << newTy <<
'\n'
642 tracker_.recordInstantiation(inputTy, newTy);
646 addConversion([
this](
ArrayType inputTy) {
647 return inputTy.cloneWith(convertType(inputTy.getElementType()));
652class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
653 ConversionTracker &tracker_;
656 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
659 : OpConversionPattern<CallOp>(converter, ctx, 2), tracker_(tracker) {}
661 LogicalResult matchAndRewrite(
662 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
664 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
667 SmallVector<Type> newResultTypes;
668 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
669 return op->emitError(
"Could not convert Op result types.");
672 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
682 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
683 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
684 tracker_.reportDelayedDiagnostics(newStTy, op);
688 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
689 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
693 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
695 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
699 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
705class FieldDefOpPattern :
public OpConversionPattern<FieldDefOp> {
707 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
710 : OpConversionPattern<FieldDefOp>(converter, ctx, 2) {}
712 LogicalResult matchAndRewrite(
713 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
715 LLVM_DEBUG(llvm::dbgs() <<
"[FieldDefOpPattern] FieldDefOp: " << op <<
'\n');
717 Type oldFieldType = op.
getType();
718 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
719 if (oldFieldType == newFieldType) {
723 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.
setType(newFieldType); });
728LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
729 MLIRContext *ctx = modOp.getContext();
730 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
733 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
734 return applyPartialConversion(modOp, target, std::move(patterns));
789std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
790 SmallVector<int64_t> res;
791 for (OpFoldResult ofr : ofrs) {
792 std::optional<int64_t> cv = getConstantIntValue(ofr);
793 if (!cv.has_value()) {
796 res.push_back(cv.value());
801struct AffineMapFolder {
803 OperandRangeRange mapOpGroups;
804 DenseI32ArrayAttr dimsPerGroup;
805 ArrayRef<Attribute> paramsOfStructTy;
809 SmallVector<SmallVector<Value>> mapOpGroups;
810 SmallVector<int32_t> dimsPerGroup;
811 SmallVector<Attribute> paramsOfStructTy;
814 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
815 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
816 return ValueRange(grp);
821 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
822 if (in.mapOpGroups.empty()) {
827 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
828 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
831 for (Attribute sizeAttr : in.paramsOfStructTy) {
832 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
833 ValueRange currMapOps = in.mapOpGroups[idx++];
838 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
840 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
843 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
844 SmallVector<Attribute> result;
845 bool hasPoison =
false;
846 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
847 return rewriter.getIndexAttr(v);
849 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
854 "Cannot fold affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
855 " due to divide by 0 or modulus with negative divisor"
860 if (failed(foldResult)) {
864 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" failed"
869 if (result.size() != 1) {
873 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
874 " produced ", result.size(),
" results but expected 1"
879 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
880 out.paramsOfStructTy.push_back(result[0]);
884 out.mapOpGroups.emplace_back(currMapOps);
885 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
888 out.paramsOfStructTy.push_back(sizeAttr);
890 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
892 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
893 "produced wrong number of dimensions"
901class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
903 ConversionTracker &tracker_;
906 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
907 : OpRewritePattern(ctx), tracker_(tracker) {}
909 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
912 AffineMapFolder::Output out;
913 AffineMapFolder::Input in = {
918 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
923 if (newResultType == oldResultType) {
928 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
930 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
931 << newResultType <<
" in \"" << op <<
"\"\n"
934 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
941class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
942 ConversionTracker &tracker_;
945 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
946 : OpRewritePattern(ctx), tracker_(tracker) {}
948 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
953 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
955 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
962 AffineMapFolder::Output out;
963 AffineMapFolder::Input in = {
968 if (!in.mapOpGroups.empty()) {
970 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
974 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
980 if (callArgTypes.empty()) {
984 SymbolTableCollection tables;
986 if (failed(lookupRes)) {
989 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
993 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
994 "result type params: "
1000 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
1001 if (newRetTy == oldRetTy) {
1009 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
1010 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1012 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1013 ", but found ", oldRetTy
1017 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1019 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1020 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1023 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1030 inline LogicalResult instantiateViaTargetType(
1031 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1032 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1037 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1039 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1045 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1047 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1049 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1051 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1058 assert(unifies &&
"should have been checked by verifiers");
1061 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1070 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1071 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1072 [&unifications](std::tuple<Attribute, Attribute> p) {
1073 Attribute fromCall = std::get<1>(p);
1076 if (!isConcreteAttr<>(fromCall)) {
1077 Attribute fromTgt = std::get<0>(p);
1079 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1080 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1082 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1083 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1084 if (it != unifications.end()) {
1085 Attribute unifiedAttr = it->second;
1087 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1089 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1098 out.paramsOfStructTy = newReturnStructParams;
1099 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1100 assert(out.mapOpGroups.empty() &&
"post-condition");
1101 assert(out.dimsPerGroup.empty() &&
"post-condition");
1106LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1107 MLIRContext *ctx = modOp.getContext();
1108 RewritePatternSet patterns(ctx);
1110 InstantiateAtCreateArrayOp,
1111 InstantiateAtCallOpCompute
1114 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1122class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1123 ConversionTracker &tracker_;
1126 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1127 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1129 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1131 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1132 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1137 Type newResultElemType =
nullptr;
1138 for (Operation *user : createResult.getUsers()) {
1139 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1140 if (writeOp.getArrRef() != createResult) {
1143 Type writeRValueType = writeOp.getRvalue().getType();
1144 if (writeRValueType == oldResultElemType) {
1147 if (newResultElemType && newResultElemType != writeRValueType) {
1150 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1151 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1155 newResultElemType = writeRValueType;
1158 if (!newResultElemType) {
1162 if (!tracker_.isLegalConversion(
1163 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1168 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1170 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
1178LogicalResult updateArrayElemFromArrAccessOp(
1180 PatternRewriter &rewriter
1187 if (oldArrType == newArrType ||
1188 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1191 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1193 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
1200class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1201 ConversionTracker &tracker_;
1204 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1205 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1207 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1208 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
1212class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1213 ConversionTracker &tracker_;
1216 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1217 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1219 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1220 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
1225class UpdateFieldDefTypeFromWrite final :
public OpRewritePattern<FieldDefOp> {
1226 ConversionTracker &tracker_;
1229 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1230 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1232 LogicalResult matchAndRewrite(
FieldDefOp op, PatternRewriter &rewriter)
const override {
1235 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
1239 Type newType =
nullptr;
1241 std::optional<Location> newTypeLoc = std::nullopt;
1242 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1243 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1244 Type writeToType = writeOp.getVal().getType();
1245 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] checking " << writeOp <<
'\n');
1248 newType = writeToType;
1249 newTypeLoc = writeOp.getLoc();
1250 }
else if (writeToType != newType) {
1256 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateFieldDefTypeFromWrite")) {
1257 if (tracker_.isLegalConversion(newType, writeToType,
"UpdateFieldDefTypeFromWrite")) {
1259 newType = writeToType;
1260 newTypeLoc = writeOp.getLoc();
1263 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1267 "' with different value types"
1270 diag.attachNote(*newTypeLoc).append(
"type written here is ", newType);
1272 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1280 if (!newType || newType == op.
getType()) {
1284 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateFieldDefTypeFromWrite")) {
1287 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1288 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] updated type of " << op <<
'\n');
1295SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1296 SmallVector<std::unique_ptr<Region>> newRegions;
1297 for (Region ®ion : op->getRegions()) {
1298 auto newRegion = std::make_unique<Region>();
1299 newRegion->takeBody(region);
1300 newRegions.push_back(std::move(newRegion));
1309class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1310 ConversionTracker &tracker_;
1313 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1314 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1316 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1317 SmallVector<Type, 1> inferredResultTypes;
1318 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1319 LogicalResult result = retTypeFn.inferReturnTypes(
1320 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1321 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1323 if (failed(result)) {
1326 if (op->getResultTypes() == inferredResultTypes) {
1330 if (!tracker_.areLegalConversions(
1331 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1337 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
1338 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1339 Operation *newOp = rewriter.create(
1340 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1341 op->getAttrs(), op->getSuccessors(), newRegions
1343 rewriter.replaceOp(op, newOp);
1344 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
1350class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1351 ConversionTracker &tracker_;
1354 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1355 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1357 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1358 Region &body = op.getFunctionBody();
1362 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1363 assert(retOp &&
"final op in body region must be return");
1364 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1367 if (oldFuncTy.getResults() == tyFromReturnOp) {
1371 if (!tracker_.areLegalConversions(
1372 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1377 rewriter.modifyOpInPlace(op, [&]() {
1378 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1381 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
1392class UpdateGlobalCallOpTypes final :
public OpRewritePattern<CallOp> {
1393 ConversionTracker &tracker_;
1396 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1397 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1399 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1400 SymbolTableCollection tables;
1402 if (failed(lookupRes)) {
1405 FuncDefOp targetFunc = lookupRes->get();
1410 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
1414 if (!tracker_.areLegalConversions(
1416 "UpdateGlobalCallOpTypes"
1421 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateGlobalCallOpTypes] replaced " << op);
1424 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1431LogicalResult updateFieldRefValFromFieldDef(
1434 SymbolTableCollection tables;
1439 Type oldResultType = op.
getVal().getType();
1440 Type newResultType = def->get().getType();
1441 if (oldResultType == newResultType ||
1442 !tracker.isLegalConversion(oldResultType, newResultType,
"updateFieldRefValFromFieldDef")) {
1445 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
1447 llvm::dbgs() <<
"[updateFieldRefValFromFieldDef] updated value type in " << op <<
'\n'
1455class UpdateFieldReadValFromDef final :
public OpRewritePattern<FieldReadOp> {
1456 ConversionTracker &tracker_;
1459 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1460 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1462 LogicalResult matchAndRewrite(
FieldReadOp op, PatternRewriter &rewriter)
const override {
1463 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1468class UpdateFieldWriteValFromDef final :
public OpRewritePattern<FieldWriteOp> {
1469 ConversionTracker &tracker_;
1472 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1473 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1475 LogicalResult matchAndRewrite(
FieldWriteOp op, PatternRewriter &rewriter)
const override {
1476 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1480LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1481 MLIRContext *ctx = modOp.getContext();
1482 RewritePatternSet patterns(ctx);
1487 UpdateInferredResultTypes,
1489 UpdateGlobalCallOpTypes,
1490 UpdateFuncTypeFromReturn,
1491 UpdateNewArrayElemFromWrite,
1492 UpdateArrayElemFromArrRead,
1493 UpdateArrayElemFromArrWrite,
1494 UpdateFieldDefTypeFromWrite,
1495 UpdateFieldReadValFromDef,
1496 UpdateFieldWriteValFromDef
1499 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1507 SymbolTableCollection tables;
1510 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1518struct FromKeepSet :
public CleanupBase {
1519 using CleanupBase::CleanupBase;
1524 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1526 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1528 rootMod.walk([&roots](Operation *op) {
1532 if (!fdef.isInStruct()) {
1541 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
1542 for (
size_t i = 0; i < roots.size(); ++i) {
1543 SymbolOpInterface keepRoot = roots[i];
1544 LLVM_DEBUG({ llvm::dbgs() <<
"[EraseUnreachable] root: " << keepRoot <<
'\n'; });
1546 assert(keepRootNode &&
"every struct def must be in the def tree");
1547 for (
const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1549 llvm::dbgs() <<
"[EraseUnreachable] can reach: " << reachableDefNode->getOp() <<
'\n';
1551 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1556 if (
const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1558 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1560 llvm::dbgs() <<
"[EraseUnreachable] uses symbol: "
1561 << usedSymbolNode->getSymbolPath() <<
'\n';
1565 if (usedSymbolNode->isStructParam()) {
1569 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1570 if (failed(lookupRes)) {
1571 LLVM_DEBUG(useGraph.dumpToDotFile());
1575 if (lookupRes->viaInclude()) {
1578 if (
StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1579 bool insertRes = roots.insert(asStruct);
1583 llvm::dbgs() <<
"[EraseUnreachable] found another root: " << asStruct <<
'\n';
1593 rootMod.walk([
this, &symbolsToKeep](
StructDefOp op) {
1596 if (!symbolsToKeep.contains(n)) {
1597 LLVM_DEBUG(llvm::dbgs() <<
"[EraseUnreachable] removing: " << op.getSymName() <<
'\n');
1601 return WalkResult::skip();
1608struct FromEraseSet :
public CleanupBase {
1613 DenseSet<SymbolRefAttr> &&tryToErasePaths
1615 : CleanupBase(root, symDefTree, symUseGraph) {
1617 for (SymbolRefAttr path : tryToErasePaths) {
1618 Operation *lookupFrom = rootMod.getOperation();
1620 assert(succeeded(res) &&
"inputs must be valid StructDefOp references");
1621 if (!res->viaInclude()) {
1622 tryToErase.insert(res->get());
1627 LogicalResult eraseUnusedStructs() {
1630 collectSafeToErase(sd);
1635 for (
auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1636 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1637 visitedPlusSafetyResult.erase(it);
1640 for (
auto &[sym, _] : visitedPlusSafetyResult) {
1641 LLVM_DEBUG(llvm::dbgs() <<
"[EraseIfUnused] removing: " << sym.getNameAttr() <<
'\n');
1647 const DenseSet<StructDefOp> &getTryToEraseSet()
const {
return tryToErase; }
1651 DenseSet<StructDefOp> tryToErase;
1655 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
1657 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1661 bool collectSafeToErase(SymbolOpInterface check) {
1665 auto visited = visitedPlusSafetyResult.find(check);
1666 if (visited != visitedPlusSafetyResult.end()) {
1667 return visited->second;
1671 if (
StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
1672 if (!tryToErase.contains(sd)) {
1673 visitedPlusSafetyResult[check] =
false;
1680 visitedPlusSafetyResult[check] =
true;
1684 if (collectSafeToErase(defTree.lookupNode(check))) {
1685 auto useNode = useGraph.lookupNode(check);
1686 assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
1687 if (!useNode || collectSafeToErase(useNode)) {
1693 visitedPlusSafetyResult[check] =
false;
1701 if (SymbolOpInterface checkOp = p->getOp()) {
1702 return collectSafeToErase(checkOp);
1712 if (SymbolOpInterface checkOp = cachedLookup(p)) {
1713 if (!collectSafeToErase(checkOp)) {
1726 assert(node &&
"must provide a node");
1728 auto fromCache = lookupCache.find(node);
1729 if (fromCache != lookupCache.end()) {
1730 return fromCache->second;
1734 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
1735 assert(lookupRes->get() !=
nullptr &&
"lookup must return an Operation");
1740 SymbolOpInterface actualRes =
1741 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1743 lookupCache[node] = actualRes;
1744 assert((!actualRes == lookupRes->viaInclude()) &&
"not found iff included");