253static inline bool tableOffsetIsntSymbol(
FieldReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
260 ConversionTracker &tracker_;
262 SymbolTableCollection symTables;
264 class MappedTypeConverter :
public TypeConverter {
267 const DenseMap<Attribute, Attribute> ¶mNameToValue;
269 inline Attribute convertIfPossible(Attribute a)
const {
270 auto res = this->paramNameToValue.find(a);
271 return (res != this->paramNameToValue.end()) ? res->second : a;
278 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
280 : TypeConverter(), origTy(originalType), newTy(newType),
281 paramNameToValue(paramNameToInstantiatedValue) {
283 addConversion([](Type inputTy) {
return inputTy; });
286 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
289 if (inputTy == this->origTy) {
293 if (ArrayAttr inputTyParams = inputTy.getParams()) {
294 SmallVector<Attribute> updated;
295 for (Attribute a : inputTyParams) {
296 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
297 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
299 updated.push_back(convertIfPossible(a));
303 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
310 addConversion([
this](
ArrayType inputTy) {
312 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
313 if (!dimSizes.empty()) {
314 SmallVector<Attribute> updated;
315 for (Attribute a : dimSizes) {
316 updated.push_back(convertIfPossible(a));
318 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
324 addConversion([
this](
TypeVarType inputTy) -> Type {
326 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
327 Type convertedType = tyAttr.getValue();
332 return convertedType;
340 template <
typename Impl,
typename Op,
typename... HandledAttrs>
341 class SymbolUserHelper :
public OpConversionPattern<Op> {
343 const DenseMap<Attribute, Attribute> ¶mNameToValue;
346 TypeConverter &converter, MLIRContext *ctx,
unsigned Benefit,
347 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
349 : OpConversionPattern<Op>(converter, ctx, Benefit),
350 paramNameToValue(paramNameToInstantiatedValue) {}
353 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
355 virtual Attribute getNameAttr(Op)
const = 0;
357 virtual LogicalResult handleDefaultRewrite(
358 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
360 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
364 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
365 auto res = this->paramNameToValue.find(getNameAttr(op));
366 if (res == this->paramNameToValue.end()) {
367 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] no instantiation for " << op <<
'\n');
370 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
371 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
373 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
374 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
378 return TS.Default([&](Attribute a) {
379 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
385 class ClonedStructConstReadOpPattern
386 :
public SymbolUserHelper<
387 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
388 SmallVector<Diagnostic> &diagnostics;
391 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
394 ClonedStructConstReadOpPattern(
395 TypeConverter &converter, MLIRContext *ctx,
396 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue,
397 SmallVector<Diagnostic> &instantiationDiagnostics
401 : super(converter, ctx, 2, paramNameToInstantiatedValue),
402 diagnostics(instantiationDiagnostics) {}
406 LogicalResult handleRewrite(
407 Attribute sym,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
409 APInt attrValue = a.getValue();
410 Type origResTy = op.getType();
411 if (llvm::isa<FeltType>(origResTy)) {
413 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
418 if (llvm::isa<IndexType>(origResTy)) {
423 if (origResTy.isSignlessInteger(1)) {
425 if (attrValue.isZero()) {
429 if (!attrValue.isOne()) {
430 Location opLoc = op.getLoc();
431 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
433 if (getContext()->shouldPrintOpOnDiagnostic()) {
434 diag.attachNote(opLoc) <<
"see current operation: " << *op;
436 diag.attachNote(UnknownLoc::get(getContext()))
438 << sym <<
"\" for this call";
439 diagnostics.push_back(std::move(diag));
444 return op->emitOpError().append(
"unexpected result type ", origResTy);
447 LogicalResult handleRewrite(
448 Attribute,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
455 class ClonedStructFieldReadOpPattern
456 :
public SymbolUserHelper<
457 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
459 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
462 ClonedStructFieldReadOpPattern(
463 TypeConverter &converter, MLIRContext *ctx,
464 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
468 : super(converter, ctx, 2, paramNameToInstantiatedValue) {}
470 Attribute getNameAttr(
FieldReadOp op)
const override {
474 template <
typename Attr>
475 LogicalResult handleRewrite(
476 Attribute,
FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
478 rewriter.modifyOpInPlace(op, [&]() {
485 LogicalResult matchAndRewrite(
486 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
488 if (tableOffsetIsntSymbol(op)) {
492 return super::matchAndRewrite(op, adaptor, rewriter);
496 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
498 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.
getDefinition(symTables, rootMod);
500 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
506 MLIRContext *ctx = origStruct.getContext();
509 DenseMap<Attribute, Attribute> paramNameToConcrete;
513 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
516 ArrayAttr reducedParamNameList =
nullptr;
518 ArrayAttr reducedCallerParams =
nullptr;
520 ArrayAttr paramNames = typeAtDef.
getParams();
524 assert(paramNames.size() == typeAtCallerParams.size());
526 SmallVector<Attribute> remainingNames;
527 SmallVector<Attribute> nonConcreteParams;
528 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
529 Attribute next = typeAtCallerParams[i];
530 if (isConcreteAttr<false>(next)) {
531 paramNameToConcrete[paramNames[i]] = next;
532 attrsForInstantiatedNameSuffix.push_back(next);
534 remainingNames.push_back(paramNames[i]);
535 nonConcreteParams.push_back(next);
536 attrsForInstantiatedNameSuffix.push_back(
nullptr);
540 assert(remainingNames.size() == nonConcreteParams.size());
541 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
542 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
544 if (paramNameToConcrete.empty()) {
545 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
548 if (!remainingNames.empty()) {
549 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
550 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
559 typeAtCaller.
getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
564 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>();
565 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
570 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
571 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
572 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
573 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
579 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
580 ConversionTarget target =
584 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
588 patterns.add<ClonedStructConstReadOpPattern>(
589 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
591 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
592 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
593 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
596 return newRemoteType;
600 StructCloner(ConversionTracker &tracker, ModuleOp root)
601 : tracker_(tracker), rootMod(root), symTables() {}
603 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
604 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
605 if (ArrayAttr params = orig.
getParams()) {
606 return genClone(orig, params.getValue());
608 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
613class ParameterizedStructUseTypeConverter :
public TypeConverter {
614 ConversionTracker &tracker_;
618 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
619 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
621 addConversion([](Type inputTy) {
return inputTy; });
625 if (
auto opt = tracker_.getInstantiation(inputTy)) {
631 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
632 if (failed(cloneRes)) {
637 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
638 <<
" as " << newTy <<
'\n'
640 tracker_.recordInstantiation(inputTy, newTy);
644 addConversion([
this](
ArrayType inputTy) {
645 return inputTy.cloneWith(convertType(inputTy.getElementType()));
650class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
651 ConversionTracker &tracker_;
654 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
657 : OpConversionPattern<CallOp>(converter, ctx, 2), tracker_(tracker) {}
659 LogicalResult matchAndRewrite(
CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter)
661 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
664 SmallVector<Type> newResultTypes;
665 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
666 return op->emitError(
"Could not convert Op result types.");
669 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
679 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
680 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
681 tracker_.reportDelayedDiagnostics(newStTy, op);
685 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
686 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
690 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
692 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
695 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
701class FieldDefOpPattern :
public OpConversionPattern<FieldDefOp> {
703 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
706 : OpConversionPattern<FieldDefOp>(converter, ctx, 2) {}
708 LogicalResult matchAndRewrite(
709 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
711 LLVM_DEBUG(llvm::dbgs() <<
"[FieldDefOpPattern] FieldDefOp: " << op <<
'\n');
713 Type oldFieldType = op.
getType();
714 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
715 if (oldFieldType == newFieldType) {
719 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.
setType(newFieldType); });
724LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
725 MLIRContext *ctx = modOp.getContext();
726 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
729 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
730 return applyPartialConversion(modOp, target, std::move(patterns));
785std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
786 SmallVector<int64_t> res;
787 for (OpFoldResult ofr : ofrs) {
788 std::optional<int64_t> cv = getConstantIntValue(ofr);
789 if (!cv.has_value()) {
792 res.push_back(cv.value());
797struct AffineMapFolder {
799 OperandRangeRange mapOpGroups;
800 DenseI32ArrayAttr dimsPerGroup;
801 ArrayRef<Attribute> paramsOfStructTy;
805 SmallVector<SmallVector<Value>> mapOpGroups;
806 SmallVector<int32_t> dimsPerGroup;
807 SmallVector<Attribute> paramsOfStructTy;
810 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
811 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
812 return ValueRange(grp);
817 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
818 if (in.mapOpGroups.empty()) {
823 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
824 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
827 for (Attribute sizeAttr : in.paramsOfStructTy) {
828 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
829 ValueRange currMapOps = in.mapOpGroups[idx++];
834 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
836 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
839 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
840 SmallVector<Attribute> result;
841 bool hasPoison =
false;
842 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
843 return rewriter.getIndexAttr(v);
845 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
847 LLVM_DEBUG(op->emitRemark().append(
848 "Cannot fold affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
849 " due to divide by 0 or modulus with negative divisor"
853 if (failed(foldResult)) {
854 LLVM_DEBUG(op->emitRemark().append(
855 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" failed"
859 if (result.size() != 1) {
860 LLVM_DEBUG(op->emitRemark().append(
861 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" produced ",
862 result.size(),
" results but expected 1"
866 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
867 out.paramsOfStructTy.push_back(result[0]);
871 out.mapOpGroups.emplace_back(currMapOps);
872 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
875 out.paramsOfStructTy.push_back(sizeAttr);
877 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
879 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
880 "produced wrong number of dimensions"
888class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
889 ConversionTracker &tracker_;
892 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
893 : OpRewritePattern(ctx), tracker_(tracker) {}
895 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
898 AffineMapFolder::Output out;
899 AffineMapFolder::Input in = {
904 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
909 if (newResultType == oldResultType) {
914 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
916 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
917 << newResultType <<
" in \"" << op <<
"\"\n"
920 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
927class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
928 ConversionTracker &tracker_;
931 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
932 : OpRewritePattern(ctx), tracker_(tracker) {}
934 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
939 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
941 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
948 AffineMapFolder::Output out;
949 AffineMapFolder::Input in = {
954 if (!in.mapOpGroups.empty()) {
956 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
960 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
966 if (callArgTypes.empty()) {
970 SymbolTableCollection tables;
972 if (failed(lookupRes)) {
975 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
979 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
980 "result type params: "
986 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
987 if (newRetTy == oldRetTy) {
995 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
996 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
998 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
999 ", but found ", oldRetTy
1003 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1005 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1006 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1008 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1015 inline LogicalResult instantiateViaTargetType(
1016 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1017 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1022 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1024 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1030 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1032 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1034 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1036 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1042 assert(unifies &&
"should have been checked by verifiers");
1045 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1054 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1055 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1056 [&unifications](std::tuple<Attribute, Attribute> p) {
1057 Attribute fromCall = std::get<1>(p);
1060 if (!isConcreteAttr<>(fromCall)) {
1061 Attribute fromTgt = std::get<0>(p);
1063 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1064 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1066 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1067 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1068 if (it != unifications.end()) {
1069 Attribute unifiedAttr = it->second;
1071 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1073 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1082 out.paramsOfStructTy = newReturnStructParams;
1083 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1084 assert(out.mapOpGroups.empty() &&
"post-condition");
1085 assert(out.dimsPerGroup.empty() &&
"post-condition");
1090LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1091 MLIRContext *ctx = modOp.getContext();
1092 RewritePatternSet patterns(ctx);
1094 InstantiateAtCreateArrayOp,
1095 InstantiateAtCallOpCompute
1098 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1106class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1107 ConversionTracker &tracker_;
1110 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1111 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1113 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1115 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1116 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1121 Type newResultElemType =
nullptr;
1122 for (Operation *user : createResult.getUsers()) {
1123 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1124 if (writeOp.getArrRef() != createResult) {
1127 Type writeRValueType = writeOp.getRvalue().getType();
1128 if (writeRValueType == oldResultElemType) {
1131 if (newResultElemType && newResultElemType != writeRValueType) {
1134 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1135 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1139 newResultElemType = writeRValueType;
1142 if (!newResultElemType) {
1146 if (!tracker_.isLegalConversion(
1147 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1152 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1154 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
1162LogicalResult updateArrayElemFromArrAccessOp(
1164 PatternRewriter &rewriter
1171 if (oldArrType == newArrType ||
1172 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1175 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1177 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
1184class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1185 ConversionTracker &tracker_;
1188 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1189 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1191 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1192 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
1196class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1197 ConversionTracker &tracker_;
1200 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1201 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1203 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1204 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
1209class UpdateFieldDefTypeFromWrite final :
public OpRewritePattern<FieldDefOp> {
1210 ConversionTracker &tracker_;
1213 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1214 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1216 LogicalResult matchAndRewrite(
FieldDefOp op, PatternRewriter &rewriter)
const override {
1219 assert(succeeded(parentRes) &&
"FieldDefOp parent is always StructDefOp");
1223 Type newType =
nullptr;
1225 std::optional<Location> newTypeLoc = std::nullopt;
1226 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1227 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1228 Type writeToType = writeOp.getVal().getType();
1229 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] checking " << writeOp <<
'\n');
1232 newType = writeToType;
1233 newTypeLoc = writeOp.getLoc();
1234 }
else if (writeToType != newType) {
1240 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateFieldDefTypeFromWrite")) {
1241 if (tracker_.isLegalConversion(newType, writeToType,
"UpdateFieldDefTypeFromWrite")) {
1243 newType = writeToType;
1244 newTypeLoc = writeOp.getLoc();
1247 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1251 "' with different value types"
1254 diag.attachNote(*newTypeLoc).append(
"type written here is ", newType);
1256 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1264 if (!newType || newType == op.
getType()) {
1268 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateFieldDefTypeFromWrite")) {
1271 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1272 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFieldDefTypeFromWrite] updated type of " << op <<
'\n');
1279SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1280 SmallVector<std::unique_ptr<Region>> newRegions;
1281 for (Region ®ion : op->getRegions()) {
1282 auto newRegion = std::make_unique<Region>();
1283 newRegion->takeBody(region);
1284 newRegions.push_back(std::move(newRegion));
1293class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1294 ConversionTracker &tracker_;
1297 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1298 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1300 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1301 SmallVector<Type, 1> inferredResultTypes;
1302 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1303 LogicalResult result = retTypeFn.inferReturnTypes(
1304 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1305 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1307 if (failed(result)) {
1310 if (op->getResultTypes() == inferredResultTypes) {
1314 if (!tracker_.areLegalConversions(
1315 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1321 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
1322 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1323 Operation *newOp = rewriter.create(
1324 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1325 op->getAttrs(), op->getSuccessors(), newRegions
1327 rewriter.replaceOp(op, newOp);
1328 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
1334class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1335 ConversionTracker &tracker_;
1338 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1339 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1341 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1342 Region &body = op.getFunctionBody();
1346 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1347 assert(retOp &&
"final op in body region must be return");
1348 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1351 if (oldFuncTy.getResults() == tyFromReturnOp) {
1355 if (!tracker_.areLegalConversions(
1356 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1361 rewriter.modifyOpInPlace(op, [&]() {
1362 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1365 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
1376class UpdateGlobalCallOpTypes final :
public OpRewritePattern<CallOp> {
1377 ConversionTracker &tracker_;
1380 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1381 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1383 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1384 SymbolTableCollection tables;
1386 if (failed(lookupRes)) {
1389 FuncDefOp targetFunc = lookupRes->get();
1394 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
1398 if (!tracker_.areLegalConversions(
1400 "UpdateGlobalCallOpTypes"
1405 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateGlobalCallOpTypes] replaced " << op);
1407 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1414LogicalResult updateFieldRefValFromFieldDef(
1417 SymbolTableCollection tables;
1422 Type oldResultType = op.
getVal().getType();
1423 Type newResultType = def->get().getType();
1424 if (oldResultType == newResultType ||
1425 !tracker.isLegalConversion(oldResultType, newResultType,
"updateFieldRefValFromFieldDef")) {
1428 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
1430 llvm::dbgs() <<
"[updateFieldRefValFromFieldDef] updated value type in " << op <<
'\n'
1438class UpdateFieldReadValFromDef final :
public OpRewritePattern<FieldReadOp> {
1439 ConversionTracker &tracker_;
1442 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1443 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1445 LogicalResult matchAndRewrite(
FieldReadOp op, PatternRewriter &rewriter)
const override {
1446 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1451class UpdateFieldWriteValFromDef final :
public OpRewritePattern<FieldWriteOp> {
1452 ConversionTracker &tracker_;
1455 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1456 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1458 LogicalResult matchAndRewrite(
FieldWriteOp op, PatternRewriter &rewriter)
const override {
1459 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1463LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1464 MLIRContext *ctx = modOp.getContext();
1465 RewritePatternSet patterns(ctx);
1470 UpdateInferredResultTypes,
1472 UpdateGlobalCallOpTypes,
1473 UpdateFuncTypeFromReturn,
1474 UpdateNewArrayElemFromWrite,
1475 UpdateArrayElemFromArrRead,
1476 UpdateArrayElemFromArrWrite,
1477 UpdateFieldDefTypeFromWrite,
1478 UpdateFieldReadValFromDef,
1479 UpdateFieldWriteValFromDef
1482 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1490 SymbolTableCollection tables;
1493 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1501struct FromKeepSet :
public CleanupBase {
1502 using CleanupBase::CleanupBase;
1507 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1509 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1511 rootMod.walk([&roots](Operation *op) {
1515 if (!fdef.isInStruct()) {
1524 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
1525 for (
size_t i = 0; i < roots.size(); ++i) {
1526 SymbolOpInterface keepRoot = roots[i];
1527 LLVM_DEBUG({ llvm::dbgs() <<
"[EraseUnreachable] root: " << keepRoot <<
'\n'; });
1529 assert(keepRootNode &&
"every struct def must be in the def tree");
1530 for (
const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1532 llvm::dbgs() <<
"[EraseUnreachable] can reach: " << reachableDefNode->getOp() <<
'\n';
1534 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1539 if (
const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1541 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1543 llvm::dbgs() <<
"[EraseUnreachable] uses symbol: "
1544 << usedSymbolNode->getSymbolPath() <<
'\n';
1548 if (usedSymbolNode->isStructParam()) {
1552 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1553 if (failed(lookupRes)) {
1554 LLVM_DEBUG(useGraph.dumpToDotFile());
1558 if (lookupRes->viaInclude()) {
1561 if (
StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1562 bool insertRes = roots.insert(asStruct);
1565 llvm::dbgs() <<
"[EraseUnreachable] found another root: " << asStruct <<
'\n';
1575 rootMod.walk([
this, &symbolsToKeep](
StructDefOp op) {
1578 if (!symbolsToKeep.contains(n)) {
1579 LLVM_DEBUG(llvm::dbgs() <<
"[EraseUnreachable] removing: " << op.getSymName() <<
'\n');
1583 return WalkResult::skip();
1590struct FromEraseSet :
public CleanupBase {
1595 DenseSet<SymbolRefAttr> &&tryToErasePaths
1597 : CleanupBase(root, symDefTree, symUseGraph) {
1599 for (SymbolRefAttr path : tryToErasePaths) {
1600 Operation *lookupFrom = rootMod.getOperation();
1602 assert(succeeded(res) &&
"inputs must be valid StructDefOp references");
1603 if (!res->viaInclude()) {
1604 tryToErase.insert(res->get());
1609 LogicalResult eraseUnusedStructs() {
1612 collectSafeToErase(sd);
1617 for (
auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1618 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1619 visitedPlusSafetyResult.erase(it);
1622 for (
auto &[sym, _] : visitedPlusSafetyResult) {
1623 LLVM_DEBUG(llvm::dbgs() <<
"[EraseIfUnused] removing: " << sym.getNameAttr() <<
'\n');
1629 const DenseSet<StructDefOp> &getTryToEraseSet()
const {
return tryToErase; }
1633 DenseSet<StructDefOp> tryToErase;
1637 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
1639 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1643 bool collectSafeToErase(SymbolOpInterface check) {
1647 auto visited = visitedPlusSafetyResult.find(check);
1648 if (visited != visitedPlusSafetyResult.end()) {
1649 return visited->second;
1653 if (
StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
1654 if (!tryToErase.contains(sd)) {
1655 visitedPlusSafetyResult[check] =
false;
1662 visitedPlusSafetyResult[check] =
true;
1666 if (collectSafeToErase(defTree.lookupNode(check))) {
1667 auto useNode = useGraph.lookupNode(check);
1668 assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
1669 if (!useNode || collectSafeToErase(useNode)) {
1675 visitedPlusSafetyResult[check] =
false;
1683 if (SymbolOpInterface checkOp = p->getOp()) {
1684 return collectSafeToErase(checkOp);
1694 if (SymbolOpInterface checkOp = cachedLookup(p)) {
1695 if (!collectSafeToErase(checkOp)) {
1708 assert(node &&
"must provide a node");
1710 auto fromCache = lookupCache.find(node);
1711 if (fromCache != lookupCache.end()) {
1712 return fromCache->second;
1716 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
1717 assert(lookupRes->get() !=
nullptr &&
"lookup must return an Operation");
1722 SymbolOpInterface actualRes =
1723 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1725 lookupCache[node] = actualRes;
1726 assert((!actualRes == lookupRes->viaInclude()) &&
"not found iff included");