/// Returns true unless the FieldReadOp's optional table offset is present and
/// is a SymbolRefAttr (an absent offset becomes nullptr via value_or, which
/// isa_and_present treats as "not a SymbolRefAttr").
/// NOTE(review): lossy excerpt — leading numbers (248, 249, ...) are original
/// line numbers fused into the code; original lines 250-254 (including the
/// closing brace) are missing.
248static inline bool tableOffsetIsntSymbol(
FieldReadOp op) {
249 return !mlir::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
// StructCloner state (the enclosing class header is outside this excerpt):
// the ConversionTracker, used in genClone to fetch the delayed-diagnostic set
// for a new instantiation, and a SymbolTableCollection reused by genClone's
// struct-definition lookup and symbol-table insertion.
255 ConversionTracker &tracker_;
257 SymbolTableCollection symTables;
/// Type converter used while instantiating a cloned struct body. Any attribute
/// that is a key of `paramNameToValue` (a struct parameter name) is replaced by
/// its instantiated value. This applies recursively through struct type
/// parameters, array dimension sizes and element types, and type variables.
/// FIX: the two reference declarations below had `&param...` corrupted by
/// HTML-entity decoding (`&para` -> U+00B6 pilcrow). They are restored so the
/// member matches its uses (`this->paramNameToValue`, ctor init at 276).
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers, and many lines are missing (e.g. the origTy/newTy member
/// declarations, the ctor name line, and the struct-conversion lambda header).
259 class MappedTypeConverter :
public TypeConverter {
262 const DenseMap<Attribute, Attribute> &paramNameToValue;
// Returns the instantiated value mapped to `a`, or `a` itself if unmapped.
264 inline Attribute convertIfPossible(Attribute a)
const {
265 auto res = this->paramNameToValue.find(a);
266 return (res != this->paramNameToValue.end()) ? res->second : a;
273 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
275 : TypeConverter(), origTy(originalType), newTy(newType),
276 paramNameToValue(paramNameToInstantiatedValue) {
// Fallback conversion: any type not handled by a more specific conversion is
// returned unchanged.
278 addConversion([](Type inputTy) {
return inputTy; });
// The header of this conversion lambda is missing. It is presumably the
// StructType conversion, given getParams()/getNameRef() below — verify.
// The body of the `origTy` case (285-287) is missing; it presumably yields newTy.
281 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
284 if (inputTy == this->origTy) {
// Rebuild the parameter list. Nested types are converted recursively; every
// other attribute is substituted when it names an instantiated parameter.
// NOTE(review): convertType can return a null Type on failure, and it is not
// checked before TypeAttr::get in the visible lines — confirm.
288 if (ArrayAttr inputTyParams = inputTy.getParams()) {
289 SmallVector<Attribute> updated;
290 for (Attribute a : inputTyParams) {
291 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
292 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
294 updated.push_back(convertIfPossible(a));
298 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
// Arrays: substitute symbolic dimension sizes and convert the element type.
305 addConversion([
this](
ArrayType inputTy) {
307 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
308 if (!dimSizes.empty()) {
309 SmallVector<Attribute> updated;
310 for (Attribute a : dimSizes) {
311 updated.push_back(convertIfPossible(a));
313 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
// Type variables: if the variable's name maps to a TypeAttr, use that type.
319 addConversion([
this](
TypeVarType inputTy) -> Type {
321 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
322 Type convertedType = tyAttr.getValue();
327 return convertedType;
/// CRTP base for conversion patterns that rewrite an op whose symbol name
/// (getNameAttr) is an instantiated struct parameter. If the name is not in
/// `paramNameToValue`, the pattern only debug-logs; the return on that path is
/// missing from this excerpt. Otherwise the mapped value is dispatched through
/// an llvm::TypeSwitch. There is one Case per `HandledAttrs` type, forwarding
/// to `Impl::handleRewrite`, and a Default that calls handleDefaultRewrite.
/// FIX: two reference declarations had `&param...` corrupted by HTML-entity
/// decoding (`&para` -> U+00B6 pilcrow). They are restored so the member matches
/// its uses (`this->paramNameToValue`, ctor init at 345).
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers and several lines (e.g. the ctor name at 340) are missing.
335 template <
typename Impl,
typename Op,
typename... HandledAttrs>
336 class SymbolUserHelper :
public OpConversionPattern<Op> {
338 const DenseMap<Attribute, Attribute> &paramNameToValue;
341 TypeConverter &converter, MLIRContext *ctx,
unsigned Benefit,
342 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
344 : OpConversionPattern<Op>(converter, ctx, Benefit),
345 paramNameToValue(paramNameToInstantiatedValue) {}
348 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
// Returns the attribute used as the lookup key into paramNameToValue.
350 virtual Attribute getNameAttr(Op)
const = 0;
// Called when the mapped value is not one of HandledAttrs; emits an op error.
352 virtual LogicalResult handleDefaultRewrite(
353 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
355 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
359 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
360 auto res = this->paramNameToValue.find(getNameAttr(op));
361 if (res == this->paramNameToValue.end()) {
362 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] no instantiation for " << op <<
'\n');
// Fold over HandledAttrs: each step chains one more Case onto the switch.
// The Case's callback statically dispatches to Impl::handleRewrite.
365 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
366 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
368 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
369 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
373 return TS.Default([&](Attribute a) {
374 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
/// Rewrites a ConstReadOp whose const name is an instantiated parameter
/// (benefit 2). It handles IntegerAttr and FeltConstAttr values. Warnings
/// produced during rewriting are appended to the caller-provided `diagnostics`
/// vector rather than emitted immediately.
/// FIX: the `paramNameToInstantiatedValue` reference parameter had `&param`
/// corrupted by HTML-entity decoding (`&para` -> U+00B6 pilcrow); restored.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers and many lines are missing, including the `super` alias head at 385
/// and most replacement statements.
380 class ClonedStructConstReadOpPattern
381 :
public SymbolUserHelper<
382 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
383 SmallVector<Diagnostic> &diagnostics;
386 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
389 ClonedStructConstReadOpPattern(
390 TypeConverter &converter, MLIRContext *ctx,
391 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
392 SmallVector<Diagnostic> &instantiationDiagnostics
396 : super(converter, ctx, 2, paramNameToInstantiatedValue),
397 diagnostics(instantiationDiagnostics) {}
// IntegerAttr value: the replacement depends on the op's original result type.
// Visible cases are felt (a FeltConstAttr with the same APInt), index, and i1.
// For i1, a value other than 0/1 records a warning in `diagnostics`. Any
// other result type is an "unexpected result type" op error.
401 LogicalResult handleRewrite(
402 Attribute sym,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
404 APInt attrValue = a.getValue();
405 Type origResTy = op.getType();
406 if (llvm::isa<FeltType>(origResTy)) {
408 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
413 if (llvm::isa<IndexType>(origResTy)) {
418 if (origResTy.isSignlessInteger(1)) {
420 if (attrValue.isZero()) {
424 if (!attrValue.isOne()) {
425 Location opLoc = op.getLoc();
426 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
428 if (getContext()->shouldPrintOpOnDiagnostic()) {
429 diag.attachNote(opLoc) <<
"see current operation: " << *op;
431 diag.attachNote(UnknownLoc::get(getContext()))
433 << sym <<
"\" for this call";
434 diagnostics.push_back(std::move(diag));
439 return op->emitOpError().append(
"unexpected result type ", origResTy);
// FeltConstAttr value (the body is missing from this excerpt).
442 LogicalResult handleRewrite(
443 Attribute,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
/// Rewrites a FieldReadOp whose table offset refers to an instantiated
/// parameter (benefit 2). It updates the op in place for IntegerAttr and
/// FeltConstAttr values.
/// FIX: the `paramNameToInstantiatedValue` reference parameter had `&param`
/// corrupted by HTML-entity decoding (`&para` -> U+00B6 pilcrow); restored.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers and several lines are missing (the getNameAttr and in-place update
/// bodies, and the early-exit result).
450 class ClonedStructFieldReadOpPattern
451 :
public SymbolUserHelper<
452 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
454 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
457 ClonedStructFieldReadOpPattern(
458 TypeConverter &converter, MLIRContext *ctx,
459 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
463 : super(converter, ctx, 2, paramNameToInstantiatedValue) {}
465 Attribute getNameAttr(
FieldReadOp op)
const override {
// Shared handler for every HandledAttrs type: modifies `op` in place.
469 template <
typename Attr>
470 LogicalResult handleRewrite(
471 Attribute,
FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
473 rewriter.modifyOpInPlace(op, [&]() {
// Ops whose table offset is absent or not a SymbolRefAttr take a separate path
// (its body is missing here). All others use the base lookup-and-dispatch.
480 LogicalResult matchAndRewrite(
481 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
483 if (tableOffsetIsntSymbol(op)) {
487 return super::matchAndRewrite(op, adaptor, rewriter);
/// Clones the StructDefOp referenced by `typeAtCaller`, binding every concrete
/// caller-side parameter value. Returns the StructType that callers should use
/// for the clone (`newRemoteType`).
/// Visible skip paths are debug-logged; their return statements are missing
/// from this excerpt. They fire when the definition cannot be resolved, or when
/// no caller parameter is concrete. Non-concrete parameters are collected into
/// reducedParamNameList / reducedCallerParams, presumably so they stay
/// parameters of the clone — confirm against the full source.
/// The clone's body is then rewritten by a full dialect conversion. It uses
/// MappedTypeConverter plus the ClonedStructConstReadOp and
/// ClonedStructFieldReadOp patterns.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Many lines are missing, including those defining origStruct,
/// typeAtDef, newStruct, newRemoteType and the pattern set.
491 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
493 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.
getDefinition(symTables, rootMod);
495 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
501 MLIRContext *ctx = origStruct.getContext();
504 DenseMap<Attribute, Attribute> paramNameToConcrete;
508 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
511 ArrayAttr reducedParamNameList =
nullptr;
513 ArrayAttr reducedCallerParams =
nullptr;
515 ArrayAttr paramNames = typeAtDef.
getParams();
519 assert(paramNames.size() == typeAtCallerParams.size());
521 SmallVector<Attribute> remainingNames;
522 SmallVector<Attribute> nonConcreteParams;
// Partition caller params. Concrete ones are bound by parameter name and
// appear in the instantiated-name suffix. Non-concrete ones are kept aside and
// get a nullptr placeholder in the suffix.
523 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
524 Attribute next = typeAtCallerParams[i];
525 if (isConcreteAttr<false>(next)) {
526 paramNameToConcrete[paramNames[i]] = next;
527 attrsForInstantiatedNameSuffix.push_back(next);
529 remainingNames.push_back(paramNames[i]);
530 nonConcreteParams.push_back(next);
531 attrsForInstantiatedNameSuffix.push_back(
nullptr);
535 assert(remainingNames.size() == nonConcreteParams.size());
536 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
537 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
539 if (paramNameToConcrete.empty()) {
540 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
543 if (!remainingNames.empty()) {
544 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
545 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
554 typeAtCaller.
getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
// Register the clone in the parent module's symbol table, inserted at the
// original struct's position.
559 ModuleOp parentModule = llvm::cast<ModuleOp>(origStruct.getParentOp());
560 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
565 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
566 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
567 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
568 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
574 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
// Only one line of the legality predicate is visible (579). It treats an op as
// legal when its const name is not a bound (now-concrete) parameter.
575 ConversionTarget target =
579 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
583 patterns.add<ClonedStructConstReadOpPattern>(
584 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
586 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
// Full conversion: instantiation fails if any op in the clone stays illegal.
587 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
588 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
591 return newRemoteType;
// StructCloner constructor: stores the tracker and root module; the symbol
// table cache starts empty.
595 StructCloner(ConversionTracker &tracker, ModuleOp root)
596 : tracker_(tracker), rootMod(root), symTables() {}
/// Entry point. Clones the definition of a parameterized struct type via
/// genClone. Types with no parameter list (getParams() is null) are
/// debug-logged and skipped; the result on that path is missing from this
/// excerpt.
598 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
599 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
600 if (ArrayAttr params = orig.
getParams()) {
601 return genClone(orig, params.getValue());
603 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
/// Type converter that replaces uses of parameterized struct types with
/// instantiated clones. An instantiation already recorded in the tracker is
/// reused. Otherwise StructCloner::createInstantiatedClone creates one, and it
/// is recorded via tracker_.recordInstantiation. Array types are converted
/// element-wise.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Several lines are missing, including the `cloner` member and the
/// struct-conversion lambda header before 620.
608class ParameterizedStructUseTypeConverter :
public TypeConverter {
609 ConversionTracker &tracker_;
613 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
614 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
// Fallback: unhandled types are returned unchanged.
616 addConversion([](Type inputTy) {
return inputTy; });
// Reuse a previously recorded instantiation of this exact type, if any.
620 if (
auto opt = tracker_.getInstantiation(inputTy)) {
626 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
627 if (failed(cloneRes)) {
632 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
633 <<
" as " << newTy <<
'\n'
635 tracker_.recordInstantiation(inputTy, newTy);
// Arrays: keep the shape and convert only the element type.
639 addConversion([
this](
ArrayType inputTy) {
640 return inputTy.cloneWith(convertType(inputTy.getElementType()));
/// Conversion pattern (benefit 2) that retargets a CallOp after struct types
/// are instantiated. It converts the call's result types with the type
/// converter. When a struct type changes, it rebuilds the callee symbol by
/// appending the original leaf name to the new struct's name. The op is then
/// replaced with a new call that passes `adapter.getMapOperands()`.
/// One visible branch also calls tracker_.reportDelayedDiagnostics(newStTy, op).
/// NOTE(review): a result-type conversion failure emits a hard error
/// (op->emitError) from inside the pattern instead of a match failure —
/// confirm this is intended.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. The branch conditions around 666-684 are missing.
645class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
646 ConversionTracker &tracker_;
649 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
652 : OpConversionPattern<CallOp>(converter, ctx, 2), tracker_(tracker) {}
654 LogicalResult matchAndRewrite(
CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter)
656 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
659 SmallVector<Type> newResultTypes;
660 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
661 return op->emitError(
"Could not convert Op result types.");
664 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
674 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
675 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
676 tracker_.reportDelayedDiagnostics(newStTy, op);
680 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
681 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
685 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
687 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
690 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
/// Conversion pattern (benefit 2) that updates a FieldDefOp's declared type to
/// the type converter's result, in place. The ConversionTracker ctor argument
/// is accepted but unused (unnamed parameter).
/// NOTE(review): the unchanged-type branch body (711-713) is missing. The
/// visible lines do not null-check convertType's result — confirm.
696class FieldDefOpPattern :
public OpConversionPattern<FieldDefOp> {
698 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
701 : OpConversionPattern<FieldDefOp>(converter, ctx, 2) {}
703 LogicalResult matchAndRewrite(
704 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
706 LLVM_DEBUG(llvm::dbgs() <<
"[FieldDefOpPattern] FieldDefOp: " << op <<
'\n');
708 Type oldFieldType = op.
getType();
709 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
710 if (oldFieldType == newFieldType) {
714 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.
setType(newFieldType); });
/// Runs a partial dialect conversion over `modOp`. It applies
/// CallStructFuncPattern and FieldDefOpPattern with a
/// ParameterizedStructUseTypeConverter, so that uses of parameterized structs
/// are replaced by instantiated clones.
/// NOTE(review): the ConversionTarget and pattern-set setup (722-723) is
/// missing from this excerpt.
719LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
720 MLIRContext *ctx = modOp.getContext();
721 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
724 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
725 return applyPartialConversion(modOp, target, std::move(patterns));
/// Collects the constant integer value of every OpFoldResult in `ofrs`, in
/// order. The branch taken when an element is not constant (785-786) is
/// missing from this excerpt; it presumably returns std::nullopt — verify.
780std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
781 SmallVector<int64_t> res;
782 for (OpFoldResult ofr : ofrs) {
783 std::optional<int64_t> cv = getConstantIntValue(ofr);
784 if (!cv.has_value()) {
787 res.push_back(cv.value());
/// Folds `affine_map` struct parameters / array dimensions to constants when
/// all of a map's operands are constant integers.
/// The fields at 794-796 are presumably `Input`: a non-owning view of the op's
/// map-operand groups, per-group dim counts and parameters. The fields at
/// 800-802 are presumably `Output`: owned copies of the same. The struct
/// headers are missing from this excerpt, so confirm against the full source.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers and several lines are missing.
792struct AffineMapFolder {
794 OperandRangeRange mapOpGroups;
795 DenseI32ArrayAttr dimsPerGroup;
796 ArrayRef<Attribute> paramsOfStructTy;
800 SmallVector<SmallVector<Value>> mapOpGroups;
801 SmallVector<int32_t> dimsPerGroup;
802 SmallVector<Attribute> paramsOfStructTy;
// Returns non-owning ValueRange views over `out.mapOpGroups`.
// FIX: `out` used to be taken by value, so the returned ranges pointed into a
// parameter copy. Whether that copy dies when this function returns or at the
// end of the caller's full-expression is implementation-defined. Taking
// `const Output &` makes the views refer to the caller's Output and avoids
// copying every operand group. The caller must keep `out` alive while the
// ranges are in use.
805 static inline SmallVector<ValueRange> getConvertedMapOpGroups(const Output &out) {
806 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
807 return ValueRange(grp);
// Walks `in.paramsOfStructTy`. Each AffineMapAttr consumes the next operand
// group. A map whose operands are all constant is folded via constantFold
// into exactly one attribute, pushed to `out`. Otherwise the operand group and
// its dim count are copied into `out`. `aspect` only labels the debug remarks.
812 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
813 if (in.mapOpGroups.empty()) {
818 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
819 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
822 for (Attribute sizeAttr : in.paramsOfStructTy) {
823 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
824 ValueRange currMapOps = in.mapOpGroups[idx++];
829 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
831 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
834 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
835 SmallVector<Attribute> result;
836 bool hasPoison =
false;
837 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
838 return rewriter.getIndexAttr(v);
840 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
842 LLVM_DEBUG(op->emitRemark().append(
843 "Cannot fold affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
844 " due to divide by 0 or modulus with negative divisor"
848 if (failed(foldResult)) {
849 LLVM_DEBUG(op->emitRemark().append(
850 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" failed"
854 if (result.size() != 1) {
855 LLVM_DEBUG(op->emitRemark().append(
856 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" produced ",
857 result.size(),
" results but expected 1"
861 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
862 out.paramsOfStructTy.push_back(result[0]);
// Not foldable: keep an owned copy of the operands and their dim count.
866 out.mapOpGroups.emplace_back(currMapOps);
867 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
870 out.paramsOfStructTy.push_back(sizeAttr);
872 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
874 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
875 "produced wrong number of dimensions"
/// Rewrite pattern that folds constant `affine_map` dimensions of a
/// CreateArrayOp's result type (via AffineMapFolder, aspect "array dimension").
/// When the type changes, the op is replaced with a new CreateArrayOp that
/// carries the remaining map operands and dimsPerGroup.
/// NOTE(review): type legality is checked only inside assert(). With NDEBUG
/// the tracker_.isLegalConversion call is compiled out, unlike
/// InstantiateAtCallOpCompute, which checks with `if` + notifyMatchFailure —
/// confirm this is intended.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including the Input initializer and the
/// computation of oldResultType/newResultType.
883class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
884 ConversionTracker &tracker_;
887 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
888 : OpRewritePattern(ctx), tracker_(tracker) {}
890 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
893 AffineMapFolder::Output out;
894 AffineMapFolder::Input in = {
899 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
904 if (newResultType == oldResultType) {
909 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
911 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
912 << newResultType <<
" in \"" << op <<
"\"\n"
915 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
/// Rewrite pattern that instantiates the struct result type of a CallOp. It
/// works in two steps:
///  1. Fold constant `affine_map` parameters in the result type
///     (AffineMapFolder, aspect "struct parameter").
///  2. For calls with arguments, resolve the callee and propagate concrete
///     values through symbol-ref parameters (instantiateViaTargetType).
/// If the resulting type is a legal conversion, the call is replaced with one
/// returning the new type. Otherwise a match failure is reported with a
/// type-mismatch diagnostic.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including the computation of oldRetTy,
/// callArgTypes, lookupRes and newRetTy.
922class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
923 ConversionTracker &tracker_;
926 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
927 : OpRewritePattern(ctx), tracker_(tracker) {}
929 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
934 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
936 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
943 AffineMapFolder::Output out;
944 AffineMapFolder::Input in = {
// Step 1: fold affine maps only when the call carries map operands.
949 if (!in.mapOpGroups.empty()) {
951 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
955 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
// Step 2: propagate instantiations from the callee's signature.
961 if (callArgTypes.empty()) {
965 SymbolTableCollection tables;
967 if (failed(lookupRes)) {
970 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
974 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
975 "result type params: "
981 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
982 if (newRetTy == oldRetTy) {
990 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
991 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
993 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
994 ", but found ", oldRetTy
998 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1000 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1001 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1003 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
/// Fills `out.paramsOfStructTy` with the call's result-type parameters. Each
/// non-concrete parameter is replaced with the concrete value it unifies to.
/// Unification matches the call's argument types against the callee's
/// signature, and the lookup uses the callee's corresponding result-type
/// parameter, which must be a SymbolRefAttr, keyed with Side::LHS.
/// Concrete parameters are kept. Post-conditions (asserted): same parameter
/// count as the input, and no map-operand groups or dim counts in `out`.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including targetResTyParams, the
/// all-concrete early path, the unification call, and the lambda's return
/// statements.
1010 inline LogicalResult instantiateViaTargetType(
1011 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1012 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1017 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1019 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1025 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1027 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1029 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1031 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1037 assert(unifies &&
"should have been checked by verifiers");
1040 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
// Pair each callee result-type parameter with the call's parameter at the
// same position (zip_equal requires equal lengths).
1049 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1050 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1051 [&unifications](std::tuple<Attribute, Attribute> p) {
1052 Attribute fromCall = std::get<1>(p);
1055 if (!isConcreteAttr<>(fromCall)) {
1056 Attribute fromTgt = std::get<0>(p);
1058 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1059 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1061 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1062 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1063 if (it != unifications.end()) {
1064 Attribute unifiedAttr = it->second;
1066 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1068 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1077 out.paramsOfStructTy = newReturnStructParams;
1078 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1079 assert(out.mapOpGroups.empty() &&
"post-condition");
1080 assert(out.dimsPerGroup.empty() &&
"post-condition");
/// Greedily applies the affine-map instantiation patterns
/// (InstantiateAtCreateArrayOp, InstantiateAtCallOpCompute) to `modOp`.
/// NOTE(review): the `patterns.add<` line (1088) is missing from this excerpt.
1085LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1086 MLIRContext *ctx = modOp.getContext();
1087 RewritePatternSet patterns(ctx);
1089 InstantiateAtCreateArrayOp,
1090 InstantiateAtCallOpCompute
1093 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
/// Rewrite pattern (benefit 3) that infers a new element type for a
/// CreateArrayOp's result from the WriteArrayOps that write into it. It only
/// considers writes whose target array (getArrRef) is the created value and
/// whose written type differs from the current element type. If two such
/// writes disagree, the conflict is debug-logged. If a single candidate is
/// found and tracker_ allows the conversion, the result type is updated in
/// place.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including createResult, oldResultElemType,
/// newType and the early-exit returns.
1101class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1102 ConversionTracker &tracker_;
1105 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1106 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1108 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1110 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1111 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1116 Type newResultElemType =
nullptr;
1117 for (Operation *user : createResult.getUsers()) {
1118 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
// Skip writes that use the new array as something other than the target.
1119 if (writeOp.getArrRef() != createResult) {
1122 Type writeRValueType = writeOp.getRvalue().getType();
1123 if (writeRValueType == oldResultElemType) {
1126 if (newResultElemType && newResultElemType != writeRValueType) {
1129 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1130 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1134 newResultElemType = writeRValueType;
1137 if (!newResultElemType) {
1141 if (!tracker_.isLegalConversion(
1142 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1147 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1149 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
/// Shared helper for the array read/write patterns. It retypes the accessed
/// array value (op.getArrRef()) to `newArrType` when the type differs and
/// `tracker` allows the conversion. The early-exit result is missing from this
/// excerpt.
/// NOTE(review): getArrRef() is an operand of `op`, so this changes the type of
/// a value defined elsewhere, while the rewriter is only told that `op` was
/// modified. Confirm the greedy driver revisits the value's users and definer
/// as needed.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers, and the parameter list and newArrType computation (1158-1165) are
/// mostly missing.
1157LogicalResult updateArrayElemFromArrAccessOp(
1159 PatternRewriter &rewriter
1166 if (oldArrType == newArrType ||
1167 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1170 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1172 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
/// Rewrite pattern (benefit 3) that updates the written-to array's type from
/// the type of the value being written, via updateArrayElemFromArrAccessOp.
1179class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1180 ConversionTracker &tracker_;
1183 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1184 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1186 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1187 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
/// Rewrite pattern (benefit 3) that updates the read-from array's type from the
/// read's result type, via updateArrayElemFromArrAccessOp.
1191class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1192 ConversionTracker &tracker_;
1195 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1196 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1198 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1199 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
/// Rewrite pattern (benefit 3) that infers a FieldDefOp's type from the
/// FieldWriteOps in its parent struct that write to it. When two writes
/// disagree, the written type that the other converts to is kept. If neither
/// direction is a legal conversion, the match fails with a diagnostic noting
/// both write locations. The field is retyped in place only when the result
/// differs from the current type and tracker_ allows the conversion.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including the parent lookup (parentRes), the
/// branch conditions around 1225-1241, and the early-exit returns.
1204class UpdateFieldDefTypeFromWrite final :
public OpRewritePattern<FieldDefOp> {
1205 ConversionTracker &tracker_;
1208 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1209 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1211 LogicalResult matchAndRewrite(
FieldDefOp op, PatternRewriter &rewriter)
const override {
1214 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
1218 Type newType =
nullptr;
// Scan every use of this field symbol within the parent struct.
1219 if (
auto fieldUsers = SymbolTable::getSymbolUses(op, parentRes.value())) {
1220 std::optional<Location> newTypeLoc = std::nullopt;
1221 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1222 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1223 Type writeToType = writeOp.getVal().getType();
1224 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] checking " << writeOp <<
'\n');
1227 newType = writeToType;
1228 newTypeLoc = writeOp.getLoc();
1229 }
else if (writeToType != newType) {
1235 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateFieldDefTypeFromWrite")) {
1236 if (tracker_.isLegalConversion(newType, writeToType,
"UpdateFieldDefTypeFromWrite")) {
1238 newType = writeToType;
1239 newTypeLoc = writeOp.getLoc();
1242 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1246 "' with different value types"
1249 diag.attachNote(*newTypeLoc).append(
"type written here is ", newType);
1251 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1259 if (!newType || newType == op.
getType()) {
1263 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateFieldDefTypeFromWrite")) {
1266 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1267 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] updated type of " << op <<
'\n');
/// Moves the body of every region of `op` into a newly allocated Region
/// (Region::takeBody leaves each source region without a body). The bodies can
/// then be attached to a replacement op.
/// FIX: the loop variable's reference declarator had `&region` corrupted by
/// HTML-entity decoding (`&reg` -> U+00AE registered sign); restored to
/// `Region &region`.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. The closing lines (1280-1287, presumably returning newRegions) are
/// missing.
1274SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1275 SmallVector<std::unique_ptr<Region>> newRegions;
1276 for (Region &region : op->getRegions()) {
1277 auto newRegion = std::make_unique<Region>();
1278 newRegion->takeBody(region);
1279 newRegions.push_back(std::move(newRegion));
/// Rewrite pattern (benefit 6) for ops with the InferTypeOpAdaptor trait. It
/// re-runs return-type inference on the current operands, attributes,
/// properties and regions. If the inferred types differ and tracker_ allows
/// the conversion, the op is replaced by a new op with the same name, operands,
/// attributes and successors. The new op takes the old op's region bodies
/// (via moveRegions) and the inferred result types.
/// NOTE(review): region bodies are moved with Region::takeBody rather than
/// through rewriter APIs, so the rewriter is not told about the moved blocks.
/// Confirm this is acceptable for the greedy driver.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers and the early-exit returns are missing.
1288class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1289 ConversionTracker &tracker_;
1292 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1293 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1295 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1296 SmallVector<Type, 1> inferredResultTypes;
1297 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1298 LogicalResult result = retTypeFn.inferReturnTypes(
1299 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1300 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1302 if (failed(result)) {
1305 if (op->getResultTypes() == inferredResultTypes) {
1309 if (!tracker_.areLegalConversions(
1310 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1316 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
1317 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1318 Operation *newOp = rewriter.create(
1319 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1320 op->getAttrs(), op->getSuccessors(), newRegions
1322 rewriter.replaceOp(op, newOp);
1323 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
/// Rewrite pattern (benefit 3) that updates a FuncDefOp's result types to match
/// the operand types of the ReturnOp terminating the body's last block. Input
/// types are kept. The update happens only when the types differ and tracker_
/// allows the conversion.
/// NOTE(review): only the last block's terminator is inspected. Multi-block
/// bodies with other returns are presumably handled elsewhere or not expected —
/// verify.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including oldFuncTy and the early-exit returns.
1329class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1330 ConversionTracker &tracker_;
1333 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1334 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1336 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1337 Region &body = op.getFunctionBody();
1341 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1342 assert(retOp &&
"final op in body region must be return");
1343 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1346 if (oldFuncTy.getResults() == tyFromReturnOp) {
1350 if (!tracker_.areLegalConversions(
1351 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1356 rewriter.modifyOpInPlace(op, [&]() {
1357 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1360 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
/// Rewrite pattern (benefit 3) that resolves a CallOp's callee FuncDefOp. If
/// the call's result types differ from the callee's declared result types (and
/// tracker_ allows the conversion), the call is replaced with one that matches.
/// NOTE(review): a fresh SymbolTableCollection is built on every match, so
/// nothing is cached across invocations.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers. Lines are missing, including the lookup call, the legality
/// arguments and the replacement op creation.
1371class UpdateGlobalCallOpTypes final :
public OpRewritePattern<CallOp> {
1372 ConversionTracker &tracker_;
1375 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1376 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1378 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1379 SymbolTableCollection tables;
1381 if (failed(lookupRes)) {
1384 FuncDefOp targetFunc = lookupRes->get();
1389 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
1393 if (!tracker_.areLegalConversions(
1395 "UpdateGlobalCallOpTypes"
1400 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateGlobalCallOpTypes] replaced " << op);
1402 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
/// Shared helper for field read/write patterns. It retypes the op's value
/// (op.getVal()) to the referenced FieldDefOp's type when the type differs and
/// `tracker` allows the conversion. The early-exit result is missing from this
/// excerpt.
/// NOTE(review): for a write, getVal() is presumably the written operand. In
/// that case this retypes a value defined by another op while the rewriter is
/// only notified about `op` — confirm.
/// NOTE(review): lossy excerpt — leading numbers are fused original line
/// numbers, and the parameter list and FieldDefOp lookup (`def`) are missing.
1409LogicalResult updateFieldRefValFromFieldDef(
1412 SymbolTableCollection tables;
1417 Type oldResultType = op.
getVal().getType();
1418 Type newResultType = def->get().getType();
1419 if (oldResultType == newResultType ||
1420 !tracker.isLegalConversion(oldResultType, newResultType,
"updateFieldRefValFromFieldDef")) {
1423 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
1425 llvm::dbgs() <<
"[updateFieldRefValFromFieldDef] updated value type in " << op <<
'\n'
/// Rewrite pattern (benefit 3) that aligns a FieldReadOp's value type with its
/// FieldDefOp, via updateFieldRefValFromFieldDef.
1433class UpdateFieldReadValFromDef final :
public OpRewritePattern<FieldReadOp> {
1434 ConversionTracker &tracker_;
1437 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1438 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1440 LogicalResult matchAndRewrite(
FieldReadOp op, PatternRewriter &rewriter)
const override {
1441 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
/// Rewrite pattern (benefit 3) that aligns a FieldWriteOp's value type with its
/// FieldDefOp, via updateFieldRefValFromFieldDef.
1446class UpdateFieldWriteValFromDef final :
public OpRewritePattern<FieldWriteOp> {
1447 ConversionTracker &tracker_;
1450 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1451 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1453 LogicalResult matchAndRewrite(
FieldWriteOp op, PatternRewriter &rewriter)
const override {
1454 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
/// Greedily applies the type-propagation patterns to `modOp`: inferred result
/// types, call/function signatures, array element types and field types.
/// NOTE(review): lossy excerpt — some pattern names (1461-1466) and the
/// function's closing lines are missing.
1458LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1459 MLIRContext *ctx = modOp.getContext();
1460 RewritePatternSet patterns(ctx);
1465 UpdateInferredResultTypes,
1467 UpdateGlobalCallOpTypes,
1468 UpdateFuncTypeFromReturn,
1469 UpdateNewArrayElemFromWrite,
1470 UpdateArrayElemFromArrRead,
1471 UpdateArrayElemFromArrWrite,
1472 UpdateFieldDefTypeFromWrite,
1473 UpdateFieldReadValFromDef,
1474 UpdateFieldWriteValFromDef
1477 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));