34#include <mlir/IR/BuiltinOps.h>
35#include <mlir/Transforms/InliningUtils.h>
36#include <mlir/Transforms/WalkPatternRewriteDriver.h>
38#include <llvm/ADT/PostOrderIterator.h>
39#include <llvm/ADT/SmallPtrSet.h>
40#include <llvm/ADT/SmallVector.h>
41#include <llvm/ADT/StringMap.h>
42#include <llvm/ADT/TypeSwitch.h>
43#include <llvm/Support/Debug.h>
50#define GEN_PASS_DECL_INLINESTRUCTSPASS
51#define GEN_PASS_DEF_INLINESTRUCTSPASS
60#define DEBUG_TYPE "llzk-inline-structs"
69using SrcStructFieldToCloneInDest = std::map<StringRef, DestCloneOfSrcStructField>;
72using DestToSrcToClonedSrcInDest =
73 DenseMap<DestFieldWithSrcStructType, SrcStructFieldToCloneInDest>;
77static inline Value getSelfValue(
FuncDefOp f) {
83 llvm_unreachable(
"expected \"@compute\" or \"@constrain\" function");
97static FailureOr<FieldWriteOp>
98findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
100 for (Operation *user : writtenValue.getUsers()) {
101 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(user)) {
103 if (writeOp.getVal() == writtenValue) {
106 auto diag = emitError().append(
"result should not be written to more than one field.");
107 diag.attachNote(foundWrite.getLoc()).append(
"written here");
108 diag.attachNote(writeOp.getLoc()).append(
"written here");
111 foundWrite = writeOp;
118 return emitError().append(
"result should be written to a field.");
126static bool combineHelper(
130 LLVM_DEBUG({ llvm::dbgs() <<
"[combineHelper] " << readOp <<
" => " << destFieldRefOp <<
'\n'; });
132 auto srcToClone = destToSrcToClone.find(getDef(tables, destFieldRefOp));
133 if (srcToClone == destToSrcToClone.end()) {
136 SrcStructFieldToCloneInDest oldToNewFields = srcToClone->second;
137 auto resNewField = oldToNewFields.find(readOp.
getFieldName());
138 if (resNewField == oldToNewFields.end()) {
143 OpBuilder builder(readOp);
145 readOp.getLoc(), readOp.getType(), destFieldRefOp.
getComponent(),
146 resNewField->second.getNameAttr()
148 readOp.replaceAllUsesWith(newRead.getOperation());
166static bool combineReadChain(
168 const DestToSrcToClonedSrcInDest &destToSrcToClone
170 LLVM_DEBUG({ llvm::dbgs() <<
"[combineReadChain] " << readOp <<
'\n'; });
173 llvm::dyn_cast_if_present<FieldReadOp>(readOp.
getComponent().getDefiningOp());
174 if (!readThatDefinesBaseComponent) {
177 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
196static LogicalResult combineNewThenReadChain(
198 const DestToSrcToClonedSrcInDest &destToSrcToClone
200 LLVM_DEBUG({ llvm::dbgs() <<
"[combineNewThenReadChain] " << readOp <<
'\n'; });
203 llvm::dyn_cast_if_present<CreateStructOp>(readOp.
getComponent().getDefiningOp());
204 if (!createThatDefinesBaseComponent) {
207 FailureOr<FieldWriteOp> foundWrite =
208 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
209 return createThatDefinesBaseComponent.emitOpError();
211 if (failed(foundWrite)) {
214 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
217static inline FieldReadOp getFieldReadThatDefinesSelfValuePassedToConstrain(
CallOp callOp) {
219 return llvm::dyn_cast_if_present<FieldReadOp>(selfArgFromCall.getDefiningOp());
224struct PendingErasure {
225 SmallPtrSet<Operation *, 8> fieldReadOps;
226 SmallPtrSet<Operation *, 8> fieldWriteOps;
227 SmallVector<CreateStructOp> newStructOps;
228 SmallVector<DestFieldWithSrcStructType> fieldDefs;
233 SymbolTableCollection &tables;
234 PendingErasure &toDelete;
236 StructDefOp srcStruct;
238 StructDefOp destStruct;
240 inline FieldDefOp getDef(FieldRefOpInterface fRef)
const { return ::getDef(tables, fRef); }
248 class FieldRefRewriter final :
public OpInterfaceRewritePattern<FieldRefOpInterface> {
256 const SrcStructFieldToCloneInDest &oldToNewFields;
260 FuncDefOp originalFunc, Value newRefBase,
261 const SrcStructFieldToCloneInDest &oldToNewFieldDef
263 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
264 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewFields(oldToNewFieldDef) {}
266 LogicalResult match(FieldRefOpInterface op)
const final {
271 return success(op.getComponent() == oldBaseVal && oldToNewFields.contains(op.getFieldName()));
274 void rewrite(FieldRefOpInterface op, PatternRewriter &rewriter)
const final {
275 rewriter.modifyOpInPlace(op, [
this, &op]() {
276 DestCloneOfSrcStructField newF = oldToNewFields.at(op.getFieldName());
277 op.setFieldName(newF.getSymName());
278 op.getComponentMutable().set(this->newBaseVal);
284 static FuncDefOp cloneWithFieldRefUpdate(std::unique_ptr<FieldRefRewriter> thisPat) {
286 FuncDefOp srcFuncClone = thisPat->funcRef.clone(mapper);
288 thisPat->funcRef = srcFuncClone;
289 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
291 MLIRContext *ctx = thisPat->getContext();
292 RewritePatternSet patterns(ctx, std::move(thisPat));
293 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
302 const StructInliner &data;
303 const DestToSrcToClonedSrcInDest &destToSrcToClone;
307 virtual FieldRefOpInterface getSelfRefField(CallOp callOp) = 0;
308 virtual void processCloneBeforeInlining(FuncDefOp func) {}
309 virtual ~ImplBase() =
default;
312 ImplBase(
const StructInliner &inliner,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
313 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
315 LogicalResult doInlining(FuncDefOp srcFunc, FuncDefOp destFunc) {
317 llvm::dbgs() <<
"[doInlining] SOURCE FUNCTION:\n";
319 llvm::dbgs() <<
"[doInlining] DESTINATION FUNCTION:\n";
323 InlinerInterface inliner(destFunc.getContext());
326 auto callHandler = [
this, &inliner, &srcFunc](CallOp callOp) {
329 assert(succeeded(callOpTarget));
330 if (callOpTarget->get() != srcFunc) {
331 return WalkResult::advance();
336 FieldRefOpInterface selfFieldRefOp = this->getSelfRefField(callOp);
337 if (!selfFieldRefOp) {
339 return WalkResult::interrupt();
345 FuncDefOp srcFuncClone = FieldRefRewriter::cloneWithFieldRefUpdate(
346 std::make_unique<FieldRefRewriter>(
348 this->destToSrcToClone.at(this->data.getDef(selfFieldRefOp))
351 this->processCloneBeforeInlining(srcFuncClone);
354 LogicalResult inlineCallRes =
355 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.
getBody(),
false);
356 if (failed(inlineCallRes)) {
358 return WalkResult::interrupt();
360 srcFuncClone.erase();
362 return WalkResult::skip();
365 auto fieldWriteHandler = [
this](FieldWriteOp writeOp) {
367 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
368 this->data.toDelete.fieldWriteOps.insert(writeOp);
370 return WalkResult::advance();
375 auto fieldReadHandler = [
this](FieldReadOp readOp) {
377 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
378 this->data.toDelete.fieldReadOps.insert(readOp);
381 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
383 : WalkResult::advance();
386 WalkResult walkRes = destFunc.
getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
387 return TypeSwitch<Operation *, WalkResult>(op)
388 .Case<CallOp>(callHandler)
389 .Case<FieldWriteOp>(fieldWriteHandler)
390 .Case<FieldReadOp>(fieldReadHandler)
391 .Default([](Operation *) {
return WalkResult::advance(); });
394 return failure(walkRes.wasInterrupted());
398 class ConstrainImpl :
public ImplBase {
399 using ImplBase::ImplBase;
401 FieldRefOpInterface getSelfRefField(CallOp callOp)
override {
402 LLVM_DEBUG({ llvm::dbgs() <<
"[ConstrainImpl::getSelfRefField] " << callOp <<
'\n'; });
407 FieldRefOpInterface selfFieldRef = getFieldReadThatDefinesSelfValuePassedToConstrain(callOp);
409 selfFieldRef.getComponent().getType() == this->data.destStruct.getType()) {
415 "\" to be passed a value read from a field in the current stuct."
422 class ComputeImpl :
public ImplBase {
423 using ImplBase::ImplBase;
425 FieldRefOpInterface getSelfRefField(CallOp callOp)
override {
426 LLVM_DEBUG({ llvm::dbgs() <<
"[ComputeImpl::getSelfRefField] " << callOp <<
'\n'; });
433 FailureOr<FieldWriteOp> foundWrite =
435 return callOp.emitOpError().append(
"\"@", FUNC_NAME_COMPUTE,
"\" ");
437 return static_cast<FieldRefOpInterface
>(foundWrite.value_or(
nullptr));
440 void processCloneBeforeInlining(FuncDefOp func)
override {
444 func.
getBody().walk([
this](CreateStructOp newStructOp) {
445 if (newStructOp.getType() == this->data.srcStruct.getType()) {
446 this->data.toDelete.newStructOps.push_back(newStructOp);
455 DestToSrcToClonedSrcInDest cloneFields() {
456 DestToSrcToClonedSrcInDest destToSrcToClone;
458 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
459 StructType srcStructType = srcStruct.getType();
460 for (FieldDefOp destField : destStruct.getFieldDefs()) {
461 if (StructType destFieldType = llvm::dyn_cast<StructType>(destField.getType())) {
466 assert(unifications.empty());
468 toDelete.fieldDefs.push_back(destField);
471 SrcStructFieldToCloneInDest &srcToClone = destToSrcToClone[destField];
472 std::vector<FieldDefOp> srcFields = srcStruct.getFieldDefs();
473 if (srcFields.empty()) {
476 OpBuilder builder(destField);
477 std::string newNameBase =
479 for (FieldDefOp srcField : srcFields) {
480 DestCloneOfSrcStructField newF = llvm::cast<FieldDefOp>(builder.clone(*srcField));
481 newF.setName(builder.getStringAttr(newNameBase +
'+' + newF.getName()));
482 srcToClone[srcField.getSymNameAttr()] = newF;
484 destStructSymTable.insert(newF);
488 return destToSrcToClone;
492 inline LogicalResult inlineConstrainCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
493 return ConstrainImpl(*
this, destToSrcToClone)
494 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
498 inline LogicalResult inlineComputeCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
499 return ComputeImpl(*
this, destToSrcToClone)
500 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
505 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp
from, StructDefOp into
507 : tables(tbls), toDelete(opsToDelete), srcStruct(
from), destStruct(into) {}
509 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
511 llvm::dbgs() <<
"[StructInliner] merge " << srcStruct.getSymNameAttr() <<
" into "
512 << destStruct.getSymNameAttr() <<
'\n'
515 DestToSrcToClonedSrcInDest destToSrcToClone = cloneFields();
516 if (failed(inlineConstrainCall(destToSrcToClone)) ||
517 failed(inlineComputeCall(destToSrcToClone))) {
520 return destToSrcToClone;
526 { t.contains(p) } -> std::convertible_to<bool>;
530template <
typename... PendingDeletionSets>
532class DanglingUseHandler {
533 SymbolTableCollection &tables;
534 const DestToSrcToClonedSrcInDest &destToSrcToClone;
535 std::tuple<
const PendingDeletionSets &...> otherRefsToBeDeleted;
539 SymbolTableCollection &symTables,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef,
540 const PendingDeletionSets &...otherRefsPendingDeletion
542 : tables(symTables), destToSrcToClone(destToSrcToCloneRef),
543 otherRefsToBeDeleted(otherRefsPendingDeletion...) {}
550 LogicalResult handle(Operation *op)
const {
551 if (op->use_empty()) {
556 llvm::dbgs() <<
"[DanglingUseHandler::handle] op: " << *op <<
'\n';
557 llvm::dbgs() <<
"[DanglingUseHandler::handle] in function: "
558 << op->getParentOfType<
FuncDefOp>() <<
'\n';
560 for (OpOperand &
use : llvm::make_early_inc_range(op->getUses())) {
561 if (
CallOp c = llvm::dyn_cast<CallOp>(
use.getOwner())) {
562 if (failed(handleUseInCallOp(
use, c, op))) {
566 Operation *user =
use.getOwner();
568 if (!opWillBeDeleted(user)) {
569 return op->emitOpError()
571 "with use in '", user->getName().getStringRef(),
572 "' is not (currently) supported by this pass."
574 .attachNote(user->getLoc())
575 .append(
"used by this operation");
580 if (!op->use_empty()) {
581 for (Operation *user : op->getUsers()) {
582 if (!opWillBeDeleted(user)) {
583 llvm::errs() <<
"Op has remaining use(s) that could not be removed: " << *op <<
'\n';
584 llvm_unreachable(
"Expected all uses to be removed");
592 inline LogicalResult handleUseInCallOp(OpOperand &
use,
CallOp inCall, Operation *origin)
const {
594 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] use in call: " << inCall <<
'\n'
596 unsigned argIdx =
use.getOperandNumber() - inCall.
getArgOperands().getBeginOperandIndex();
598 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] at index: " << argIdx <<
'\n'
602 if (failed(tgtFuncRes)) {
604 ->emitOpError(
"as argument to an unknown function is not supported by this pass.")
605 .attachNote(inCall.getLoc())
606 .append(
"used by this call");
610 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] call target: " << tgtFunc <<
'\n'
612 if (tgtFunc.isExternal()) {
616 ->emitOpError(
"as argument to a no-body free function is not supported by this pass.")
617 .attachNote(inCall.getLoc())
618 .append(
"used by this call");
622 .template Case<FieldReadOp>([](
auto p) {
return p; })
623 .
template Case<CreateStructOp>([](
auto p) {
624 return findOpThatStoresSubcmp(p, [&p]() {
return p.emitOpError(); }).value_or(
nullptr);
625 }).Default([](Operation *p) {
626 llvm::errs() <<
"Encountered unexpected op: "
627 << (p ? p->getName().getStringRef() :
"<<null>>") <<
'\n';
628 llvm_unreachable(
"Unexpected op kind");
632 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] field ref op for param: "
635 if (!paramFromField) {
638 const SrcStructFieldToCloneInDest &newFields =
639 destToSrcToClone.at(getDef(tables, paramFromField));
641 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] fields to split: "
646 splitFunctionParam(tgtFunc, argIdx, newFields);
648 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target: " << tgtFunc
650 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target type: "
656 OpBuilder builder(inCall);
657 SmallVector<Value> splitArgs;
661 for (
auto [origName, newFieldRef] : newFields) {
663 inCall.getLoc(), newFieldRef.getType(), originalBaseVal, newFieldRef.getNameAttr()
669 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
672 inCall.replaceAllUsesWith(builder.create<
CallOp>(
678 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED function: "
679 << origin->getParentOfType<
FuncDefOp>() <<
'\n';
685 inline bool opWillBeDeleted(Operation *otherOp)
const {
686 return std::apply([&](
const auto &...sets) {
687 return ((sets.contains(otherOp)) || ...);
688 }, otherRefsToBeDeleted);
695 static void splitFunctionParam(
696 FuncDefOp func,
unsigned paramIdx,
const SrcStructFieldToCloneInDest &nameToNewField
700 const SrcStructFieldToCloneInDest &newFields;
703 Impl(
unsigned paramIdx,
const SrcStructFieldToCloneInDest &nameToNewField)
704 : inputIdx(paramIdx), newFields(nameToNewField) {}
707 SmallVector<Type>
convertInputs(ArrayRef<Type> origTypes)
override {
708 SmallVector<Type> newTypes(origTypes);
709 auto it = newTypes.erase(newTypes.begin() + inputIdx);
710 for (
auto [_, newField] : newFields) {
711 newTypes.insert(it, newField.getType());
716 SmallVector<Type>
convertResults(ArrayRef<Type> origTypes)
override {
717 return SmallVector<Type>(origTypes);
722 SmallVector<Attribute> newAttrs(origAttrs.getValue());
723 newAttrs.insert(newAttrs.begin() + inputIdx, newFields.size() - 1, origAttrs[inputIdx]);
724 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
733 Value oldStructRef = entryBlock.getArgument(inputIdx);
737 llvm::StringMap<BlockArgument> fieldNameToNewArg;
738 Location loc = oldStructRef.getLoc();
739 unsigned idx = inputIdx;
740 for (
auto [fieldName, newField] : newFields) {
742 BlockArgument newArg = entryBlock.insertArgument(++idx, newField.getType(), loc);
743 fieldNameToNewArg[fieldName] = newArg;
748 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
749 if (FieldReadOp readOp = llvm::dyn_cast<FieldReadOp>(oldBlockArgUse.getOwner())) {
751 BlockArgument newArg = fieldNameToNewArg.at(readOp.
getFieldName());
752 rewriter.replaceAllUsesWith(readOp, newArg);
753 rewriter.eraseOp(readOp);
758 llvm::errs() <<
"Unexpected use of " << oldBlockArgUse.get() <<
" in "
759 << *oldBlockArgUse.getOwner() <<
'\n';
760 llvm_unreachable(
"Not yet implemented");
764 entryBlock.eraseArgument(inputIdx);
767 IRRewriter rewriter(func.getContext());
768 Impl(paramIdx, nameToNewField).convert(func, rewriter);
772static LogicalResult finalizeStruct(
773 SymbolTableCollection &tables,
StructDefOp caller, PendingErasure &&toDelete,
774 DestToSrcToClonedSrcInDest &&destToSrcToClone
777 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
778 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
779 llvm::dbgs() <<
'\n';
784 combineReadChain(readOp, tables, destToSrcToClone);
788 auto res = computeFn.walk([&tables, &destToSrcToClone, &computeSelfVal](
FieldReadOp readOp) {
789 combineReadChain(readOp, tables, destToSrcToClone);
793 return WalkResult::advance();
795 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
796 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
798 if (res.wasInterrupted()) {
803 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
804 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
805 llvm::dbgs() <<
'\n';
806 llvm::dbgs() <<
"[finalizeStruct] ops marked for deletion:\n";
807 for (Operation *op : toDelete.fieldReadOps) {
808 llvm::dbgs().indent(2) << *op <<
'\n';
810 for (Operation *op : toDelete.fieldWriteOps) {
811 llvm::dbgs().indent(2) << *op <<
'\n';
814 llvm::dbgs().indent(2) << op <<
'\n';
816 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
817 llvm::dbgs().indent(2) << op <<
'\n';
823 DanglingUseHandler<SmallPtrSet<Operation *, 8>, SmallPtrSet<Operation *, 8>> useHandler(
824 tables, destToSrcToClone, toDelete.fieldWriteOps, toDelete.fieldReadOps
827 if (failed(useHandler.handle(op))) {
833 for (Operation *op : toDelete.fieldWriteOps) {
834 if (failed(useHandler.handle(op))) {
839 for (Operation *op : toDelete.fieldReadOps) {
840 if (failed(useHandler.handle(op))) {
849 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
850 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
851 assert(op.getParentOp() == caller);
852 callerSymTab.erase(op);
861 for (
auto &[caller, callees] : plan) {
864 PendingErasure toDelete;
866 DestToSrcToClonedSrcInDest aggregateReplacements;
869 FailureOr<DestToSrcToClonedSrcInDest> res =
870 StructInliner(tables, toDelete, toInline, caller).doInline();
875 for (
auto &[k, v] : res.value()) {
876 assert(!aggregateReplacements.contains(k) &&
"duplicate not possible");
877 aggregateReplacements[k] = std::move(v);
881 LogicalResult finalizeResult =
882 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
883 if (failed(finalizeResult)) {
893 static uint64_t complexity(FuncDefOp f) {
894 uint64_t complexity = 0;
895 f.
getBody().walk([&complexity](Operation *op) {
896 if (llvm::isa<felt::MulFeltOp>(op)) {
898 }
else if (
auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
900 }
else if (
auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
909 static FailureOr<FuncDefOp>
910 getIfStructConstrain(
const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
912 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
913 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
923 static inline StructDefOp getParentStruct(FuncDefOp func) {
926 assert(succeeded(currentNodeParentStruct));
927 return currentNodeParentStruct.value();
931 inline bool exceedsMaxComplexity(uint64_t check) {
932 return maxComplexity > 0 && check > maxComplexity;
937 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
949 WalkResult res = currentFunc.walk([](CallOp c) {
950 return getFieldReadThatDefinesSelfValuePassedToConstrain(c)
951 ? WalkResult::interrupt()
952 : WalkResult::advance();
958 return res.wasInterrupted();
965 inline FailureOr<InliningPlan>
966 makePlan(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
968 llvm::dbgs() <<
"Running InlineStructsPass with max complexity ";
969 if (maxComplexity == 0) {
970 llvm::dbgs() <<
"unlimited";
972 llvm::dbgs() << maxComplexity;
974 llvm::dbgs() <<
'\n';
977 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
989 for (
const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
990 LLVM_DEBUG(llvm::dbgs() <<
"\ncurrentNode = " << currentNode->toString());
991 if (!currentNode->isRealNode()) {
994 if (currentNode->isStructParam()) {
996 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
1000 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
1001 return reportLoc->emitError(
"Cannot inline structs with parameters.");
1003 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
1004 if (failed(currentFuncOpt)) {
1007 FuncDefOp currentFunc = currentFuncOpt.value();
1008 uint64_t currentComplexity = complexity(currentFunc);
1010 if (exceedsMaxComplexity(currentComplexity)) {
1011 complexityMemo[currentNode] = currentComplexity;
1016 SmallVector<StructDefOp> successorsToMerge;
1017 for (
const SymbolUseGraphNode *successor : currentNode->successorIter()) {
1018 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"successor: " << successor->toString() <<
'\n');
1020 auto memoResult = complexityMemo.find(successor);
1021 if (memoResult == complexityMemo.end()) {
1024 uint64_t sComplexity = memoResult->second;
1026 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
1027 "addition will overflow"
1029 uint64_t potentialComplexity = currentComplexity + sComplexity;
1030 if (!exceedsMaxComplexity(potentialComplexity)) {
1031 currentComplexity = potentialComplexity;
1032 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
1033 assert(succeeded(successorFuncOpt));
1034 FuncDefOp successorFunc = successorFuncOpt.value();
1035 if (canInline(currentFunc, successorFunc)) {
1036 successorsToMerge.push_back(getParentStruct(successorFunc));
1040 complexityMemo[currentNode] = currentComplexity;
1041 if (!successorsToMerge.empty()) {
1042 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
1046 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1047 llvm::dbgs() <<
"InlineStructsPass plan:\n";
1048 for (
auto &[caller, callees] : retVal) {
1049 llvm::dbgs().indent(2) <<
"inlining the following into \"" << caller.
getSymName() <<
"\"\n";
1050 for (StructDefOp c : callees) {
1051 llvm::dbgs().indent(4) <<
"\"" << c.getSymName() <<
"\"\n";
1054 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1060 void runOnOperation()
override {
1061 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
1064 SymbolTableCollection tables;
1065 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
1067 signalPassFailure();
1072 signalPassFailure();
1081 return std::make_unique<InlineStructsPass>();
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan)
mlir::SmallVector< std::pair< llzk::component::StructDefOp, mlir::SmallVector< llzk::component::StructDefOp > > > InliningPlan
Maps caller struct to callees that should be inlined.
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
::llvm::StringRef getFieldName()
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the FieldRefOp.
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
void print(::mlir::OpAsmPrinter &_odsPrinter)
::mlir::Operation::operand_range getArgOperands()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FunctionType getFunctionType()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
::mlir::Region & getBody()
std::string toStringOne(const T &value)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::unique_ptr< mlir::Pass > createInlineStructsPass()
bool hasCycle(const GraphT &G)