35#include <mlir/IR/BuiltinOps.h>
36#include <mlir/Transforms/InliningUtils.h>
38#include <llvm/ADT/PostOrderIterator.h>
39#include <llvm/ADT/SmallVector.h>
40#include <llvm/ADT/StringMap.h>
41#include <llvm/ADT/TypeSwitch.h>
42#include <llvm/Support/Debug.h>
47#define GEN_PASS_DECL_INLINESTRUCTSPASS
48#define GEN_PASS_DEF_INLINESTRUCTSPASS
57#define DEBUG_TYPE "llzk-inline-structs"
66using SrcStructFieldToCloneInDest = std::map<StringRef, DestCloneOfSrcStructField>;
69using DestToSrcToClonedSrcInDest =
70 DenseMap<DestFieldWithSrcStructType, SrcStructFieldToCloneInDest>;
78 llvm_unreachable(
"expected \"compute\" or \"constrain\" function");
95 auto srcToClone = destToSrcToClone.find(getDef(tables, destFieldRefOp));
96 if (srcToClone == destToSrcToClone.end()) {
99 SrcStructFieldToCloneInDest oldToNewFields = srcToClone->second;
100 auto resNewField = oldToNewFields.find(readOp.
getFieldName());
101 if (resNewField == oldToNewFields.end()) {
106 OpBuilder builder(readOp);
108 readOp.getLoc(), readOp.getType(), destFieldRefOp.
getComponent(),
109 resNewField->second.getNameAttr()
111 readOp.replaceAllUsesWith(newRead.getOperation());
129bool combineReadChain(
131 const DestToSrcToClonedSrcInDest &destToSrcToClone
134 llvm::dyn_cast_if_present<FieldReadOp>(readOp.
getComponent().getDefiningOp());
135 if (!readThatDefinesBaseComponent) {
138 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
143FailureOr<FieldWriteOp>
144findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
146 for (Operation *user : writtenValue.getUsers()) {
147 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(user)) {
149 if (writeOp.getVal() == writtenValue) {
152 auto diag = emitError().append(
"result should not be written to more than one field.");
153 diag.attachNote(foundWrite.getLoc()).append(
"written here");
154 diag.attachNote(writeOp.getLoc()).append(
"written here");
157 foundWrite = writeOp;
164 return emitError().append(
"result should be written to a field.");
185LogicalResult combineNewThenReadChain(
187 const DestToSrcToClonedSrcInDest &destToSrcToClone
190 llvm::dyn_cast_if_present<CreateStructOp>(readOp.
getComponent().getDefiningOp());
191 if (!createThatDefinesBaseComponent) {
194 FailureOr<FieldWriteOp> foundWrite =
195 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
196 return createThatDefinesBaseComponent.emitOpError();
198 if (failed(foundWrite)) {
201 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
204inline FieldReadOp getFieldReadThatDefinesSelfValuePassedToConstrain(
CallOp callOp) {
206 return llvm::dyn_cast_if_present<FieldReadOp>(selfArgFromCall.getDefiningOp());
211struct PendingErasure {
212 SmallVector<FieldRefOpInterface> fieldRefOps;
213 SmallVector<CreateStructOp> newStructOps;
214 SmallVector<DestFieldWithSrcStructType> fieldDefs;
218 SymbolTableCollection &tables;
219 PendingErasure &toDelete;
221 StructDefOp srcStruct;
223 StructDefOp destStruct;
/// Convenience wrapper that resolves the FieldDefOp referenced by the field
/// reference op `fRef`. It forwards to the free function `::getDef`, passing
/// this object's `tables` (a SymbolTableCollection reference member).
/// NOTE(review): lookup/failure behavior is whatever `::getDef` does; that
/// helper is not visible in this view, so confirm there.
225 inline FieldDefOp getDef(FieldRefOpInterface fRef)
 const { return ::getDef(tables, fRef); }
233 class FieldRefRewriter final :
public OpInterfaceRewritePattern<FieldRefOpInterface> {
241 const SrcStructFieldToCloneInDest &oldToNewFields;
245 FuncDefOp originalFunc, Value newRefBase,
246 const SrcStructFieldToCloneInDest &oldToNewFieldDef
248 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
249 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewFields(oldToNewFieldDef) {}
251 LogicalResult match(FieldRefOpInterface op)
const final {
256 return success(op.getComponent() == oldBaseVal && oldToNewFields.contains(op.getFieldName()));
259 void rewrite(FieldRefOpInterface op, PatternRewriter &rewriter)
const final {
260 rewriter.modifyOpInPlace(op, [
this, &op]() {
261 DestCloneOfSrcStructField newF = oldToNewFields.at(op.getFieldName());
262 op.setFieldName(newF.getSymName());
263 op.getComponentMutable().set(this->newBaseVal);
269 static FuncDefOp cloneWithFieldRefUpdate(std::unique_ptr<FieldRefRewriter> thisPat) {
271 FuncDefOp srcFuncClone = thisPat->funcRef.clone(mapper);
273 thisPat->funcRef = srcFuncClone;
274 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
276 MLIRContext *ctx = thisPat->getContext();
277 RewritePatternSet patterns(ctx, std::move(thisPat));
287 const StructInliner &data;
288 const DestToSrcToClonedSrcInDest &destToSrcToClone;
292 virtual FieldRefOpInterface getSelfRefField(CallOp callOp) = 0;
293 virtual void processCloneBeforeInlining(FuncDefOp func) {}
294 virtual ~ImplBase() =
default;
297 ImplBase(
const StructInliner &inliner,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
298 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
300 LogicalResult doInlining(FuncDefOp srcFunc, FuncDefOp destFunc) {
302 llvm::dbgs() <<
"[doInlining] SOURCE FUNCTION:\n";
304 llvm::dbgs() <<
"[doInlining] DESTINATION FUNCTION:\n";
308 InlinerInterface inliner(destFunc.getContext());
311 auto callHandler = [
this, &inliner, &srcFunc](CallOp callOp) {
314 assert(succeeded(callOpTarget));
315 if (callOpTarget->get() != srcFunc) {
316 return WalkResult::advance();
321 FieldRefOpInterface selfFieldRefOp = this->getSelfRefField(callOp);
322 if (!selfFieldRefOp) {
324 return WalkResult::interrupt();
330 FuncDefOp srcFuncClone =
331 FieldRefRewriter::cloneWithFieldRefUpdate(std::make_unique<FieldRefRewriter>(
333 this->destToSrcToClone.at(this->data.getDef(selfFieldRefOp))
335 this->processCloneBeforeInlining(srcFuncClone);
338 LogicalResult inlineCallRes =
339 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.
getBody(),
false);
340 if (failed(inlineCallRes)) {
342 return WalkResult::interrupt();
344 srcFuncClone.erase();
346 return WalkResult::skip();
349 auto fieldWriteHandler = [
this](FieldWriteOp writeOp) {
351 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
352 this->data.toDelete.fieldRefOps.push_back(writeOp);
354 return WalkResult::advance();
359 auto fieldReadHandler = [
this](FieldReadOp readOp) {
361 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
362 this->data.toDelete.fieldRefOps.push_back(readOp);
365 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
367 : WalkResult::advance();
370 WalkResult walkRes = destFunc.
getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
371 return TypeSwitch<Operation *, WalkResult>(op)
372 .Case<CallOp>(callHandler)
373 .Case<FieldWriteOp>(fieldWriteHandler)
374 .Case<FieldReadOp>(fieldReadHandler)
375 .Default([](Operation *) {
return WalkResult::advance(); });
378 return failure(walkRes.wasInterrupted());
382 class ConstrainImpl :
public ImplBase {
383 using ImplBase::ImplBase;
385 FieldRefOpInterface getSelfRefField(CallOp callOp)
override {
389 FieldRefOpInterface selfFieldRef = getFieldReadThatDefinesSelfValuePassedToConstrain(callOp);
391 selfFieldRef.
getComponent().getType() == this->data.destStruct.getType()) {
397 "\" to be passed a value read from a field in the current stuct."
404 class ComputeImpl :
public ImplBase {
405 using ImplBase::ImplBase;
407 FieldRefOpInterface getSelfRefField(CallOp callOp)
override {
413 FailureOr<FieldWriteOp> foundWrite =
415 return callOp.emitOpError().append(
"\"@", FUNC_NAME_COMPUTE,
"\" ");
417 return static_cast<FieldRefOpInterface
>(foundWrite.value_or(
nullptr));
420 void processCloneBeforeInlining(FuncDefOp func)
override {
424 func.
getBody().walk([
this](CreateStructOp newStructOp) {
425 if (newStructOp.getType() == this->data.srcStruct.getType()) {
426 this->data.toDelete.newStructOps.push_back(newStructOp);
435 DestToSrcToClonedSrcInDest cloneFields() {
436 DestToSrcToClonedSrcInDest destToSrcToClone;
438 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
439 StructType srcStructType = srcStruct.getType();
440 for (FieldDefOp destField : destStruct.getFieldDefs()) {
441 if (StructType destFieldType = llvm::dyn_cast<StructType>(destField.getType())) {
446 assert(unifications.empty());
448 toDelete.fieldDefs.push_back(destField);
451 SrcStructFieldToCloneInDest &srcToClone = destToSrcToClone.getOrInsertDefault(destField);
452 std::vector<FieldDefOp> srcFields = srcStruct.getFieldDefs();
453 if (srcFields.empty()) {
456 OpBuilder builder(destField);
457 std::string newNameBase =
459 for (FieldDefOp srcField : srcFields) {
460 DestCloneOfSrcStructField newF = llvm::cast<FieldDefOp>(builder.clone(*srcField));
461 newF.setName(builder.getStringAttr(newNameBase +
'+' + newF.getName()));
462 srcToClone[srcField.getSymNameAttr()] = newF;
464 destStructSymTable.insert(newF);
468 return destToSrcToClone;
472 inline LogicalResult inlineConstrainCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
473 return ConstrainImpl(*
this, destToSrcToClone)
474 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
478 inline LogicalResult inlineComputeCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
479 return ComputeImpl(*
this, destToSrcToClone)
480 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
485 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp
from, StructDefOp into
487 : tables(tbls), toDelete(opsToDelete), srcStruct(
from), destStruct(into) {}
489 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
491 llvm::dbgs() <<
"[StructInliner] merge " << srcStruct.getSymNameAttr() <<
" into "
492 << destStruct.getSymNameAttr() <<
'\n'
495 DestToSrcToClonedSrcInDest destToSrcToClone = cloneFields();
496 if (failed(inlineConstrainCall(destToSrcToClone)) ||
497 failed(inlineComputeCall(destToSrcToClone))) {
500 return destToSrcToClone;
507inline void splitFunctionParam(
508 FuncDefOp func,
unsigned paramIdx,
const SrcStructFieldToCloneInDest &nameToNewField
512 const SrcStructFieldToCloneInDest &newFields;
515 Impl(
unsigned paramIdx,
const SrcStructFieldToCloneInDest &nameToNewField)
516 : inputIdx(paramIdx), newFields(nameToNewField) {}
519 SmallVector<Type>
convertInputs(ArrayRef<Type> origTypes)
override {
520 SmallVector<Type> newTypes(origTypes);
521 auto it = newTypes.erase(newTypes.begin() + inputIdx);
522 for (
auto [_, newField] : newFields) {
523 newTypes.insert(it, newField.getType());
528 SmallVector<Type>
convertResults(ArrayRef<Type> origTypes)
override {
529 return SmallVector<Type>(origTypes);
534 SmallVector<Attribute> newAttrs(origAttrs.getValue());
535 newAttrs.insert(newAttrs.begin() + inputIdx, newFields.size() - 1, origAttrs[inputIdx]);
536 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
545 Value oldStructRef = entryBlock.getArgument(inputIdx);
549 llvm::StringMap<BlockArgument> fieldNameToNewArg;
550 Location loc = oldStructRef.getLoc();
551 unsigned idx = inputIdx;
552 for (
auto [fieldName, newField] : newFields) {
554 BlockArgument newArg = entryBlock.insertArgument(++idx, newField.getType(), loc);
555 fieldNameToNewArg[fieldName] = newArg;
560 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
561 if (FieldReadOp readOp = llvm::dyn_cast<FieldReadOp>(oldBlockArgUse.getOwner())) {
563 BlockArgument newArg = fieldNameToNewArg.at(readOp.
getFieldName());
564 rewriter.replaceAllUsesWith(readOp, newArg);
565 rewriter.eraseOp(readOp);
570 llvm::errs() <<
"Unexpected use of " << oldBlockArgUse.get() <<
" in "
571 << *oldBlockArgUse.getOwner() <<
'\n';
572 llvm_unreachable(
"Not yet implemented");
576 entryBlock.eraseArgument(inputIdx);
579 IRRewriter rewriter(func.getContext());
580 Impl(paramIdx, nameToNewField).convert(func, rewriter);
590 using InliningPlan = SmallVector<std::pair<StructDefOp, SmallVector<StructDefOp>>>;
592 static uint64_t complexity(FuncDefOp f) {
593 uint64_t complexity = 0;
594 f.
getBody().walk([&complexity](Operation *op) {
595 if (llvm::isa<felt::MulFeltOp>(op)) {
597 }
else if (
auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
599 }
else if (
auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
608 static FailureOr<FuncDefOp>
609 getIfStructConstrain(
const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
611 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
612 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
622 static inline StructDefOp getParentStruct(FuncDefOp func) {
625 assert(succeeded(currentNodeParentStruct));
626 return currentNodeParentStruct.value();
630 inline bool exceedsMaxComplexity(uint64_t check) {
631 return maxComplexity > 0 && check > maxComplexity;
636 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
648 WalkResult res = currentFunc.walk([](CallOp c) {
649 return getFieldReadThatDefinesSelfValuePassedToConstrain(c)
650 ? WalkResult::interrupt()
651 : WalkResult::advance();
657 return res.wasInterrupted();
664 inline FailureOr<InliningPlan>
665 makePlan(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
667 llvm::dbgs() <<
"Running InlineStructsPass with max complexity ";
668 if (maxComplexity == 0) {
669 llvm::dbgs() <<
"unlimited";
671 llvm::dbgs() << maxComplexity;
673 llvm::dbgs() <<
'\n';
676 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
688 for (
const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
689 LLVM_DEBUG(llvm::dbgs() <<
"\ncurrentNode = " << currentNode->toString());
690 if (!currentNode->isRealNode()) {
693 if (currentNode->isStructParam()) {
695 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
699 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
700 return reportLoc->emitError(
"Cannot inline structs with parameters.");
702 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
703 if (failed(currentFuncOpt)) {
706 FuncDefOp currentFunc = currentFuncOpt.value();
707 uint64_t currentComplexity = complexity(currentFunc);
709 if (exceedsMaxComplexity(currentComplexity)) {
710 complexityMemo[currentNode] = currentComplexity;
715 SmallVector<StructDefOp> successorsToMerge;
716 for (
const SymbolUseGraphNode *successor : currentNode->successorIter()) {
717 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"successor: " << successor->toString() <<
'\n');
719 auto memoResult = complexityMemo.find(successor);
720 if (memoResult == complexityMemo.end()) {
723 uint64_t sComplexity = memoResult->second;
725 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
726 "addition will overflow"
728 uint64_t potentialComplexity = currentComplexity + sComplexity;
729 if (!exceedsMaxComplexity(potentialComplexity)) {
730 currentComplexity = potentialComplexity;
731 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
732 assert(succeeded(successorFuncOpt));
733 FuncDefOp successorFunc = successorFuncOpt.value();
734 if (canInline(currentFunc, successorFunc)) {
735 successorsToMerge.push_back(getParentStruct(successorFunc));
739 complexityMemo[currentNode] = currentComplexity;
740 if (!successorsToMerge.empty()) {
741 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
745 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
746 llvm::dbgs() <<
"InlineStructsPass plan:\n";
747 for (
auto &[caller, callees] : retVal) {
748 llvm::dbgs().indent(2) <<
"inlining the following into \"" << caller.getSymName() <<
"\"\n";
749 for (StructDefOp c : callees) {
750 llvm::dbgs().indent(4) <<
"\"" << c.getSymName() <<
"\"\n";
753 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
763 static LogicalResult handleRemainingUses(
764 Operation *op, SymbolTableCollection &tables,
765 const DestToSrcToClonedSrcInDest &destToSrcToClone,
766 ArrayRef<FieldRefOpInterface> otherRefsToBeDeleted = {}
768 if (op->use_empty()) {
773 auto opWillBeDeleted = [&otherRefsToBeDeleted](Operation *op) ->
bool {
774 return std::find(otherRefsToBeDeleted.begin(), otherRefsToBeDeleted.end(), op) !=
775 otherRefsToBeDeleted.end();
779 llvm::dbgs() <<
"[handleRemainingUses] op: " << *op <<
'\n';
780 llvm::dbgs() <<
"[handleRemainingUses] in function: " << op->getParentOfType<FuncDefOp>()
783 for (OpOperand &
use : llvm::make_early_inc_range(op->getUses())) {
784 if (CallOp c = llvm::dyn_cast<CallOp>(
use.getOwner())) {
785 LLVM_DEBUG(llvm::dbgs() <<
"[handleRemainingUses] use in call: " << c <<
'\n');
786 unsigned argIdx =
use.getOperandNumber() - c.
getArgOperands().getBeginOperandIndex();
787 LLVM_DEBUG(llvm::dbgs() <<
"[handleRemainingUses] at index: " << argIdx <<
'\n');
790 if (failed(tgtFuncRes)) {
792 ->emitOpError(
"as argument to an unknown function is not supported by this pass.")
793 .attachNote(c.getLoc())
794 .append(
"used by this call");
796 FuncDefOp tgtFunc = tgtFuncRes->get();
797 LLVM_DEBUG(llvm::dbgs() <<
"[handleRemainingUses] call target: " << tgtFunc <<
'\n');
798 if (tgtFunc.isExternal()) {
802 ->emitOpError(
"as argument to a no-body free function is not supported by this pass.")
803 .attachNote(c.getLoc())
804 .append(
"used by this call");
807 FieldRefOpInterface paramFromField = TypeSwitch<Operation *, FieldRefOpInterface>(op)
808 .Case<FieldReadOp>([](
auto p) {
return p; })
809 .Case<CreateStructOp>([](
auto p) {
810 return findOpThatStoresSubcmp(p, [&p]() {
return p.emitOpError(); }).value_or(
nullptr);
811 }).Default([](Operation *p) {
812 llvm::errs() <<
"Encountered unexpected op: "
813 << (p ? p->getName().getStringRef() :
"<<null>>") <<
'\n';
814 llvm_unreachable(
"Unexpected op kind");
818 llvm::dbgs() <<
"[handleRemainingUses] field ref op for param: "
822 if (!paramFromField) {
825 const SrcStructFieldToCloneInDest &newFields =
826 destToSrcToClone.at(getDef(tables, paramFromField));
828 llvm::dbgs() <<
"[handleRemainingUses] fields to split: "
833 splitFunctionParam(tgtFunc, argIdx, newFields);
835 llvm::dbgs() <<
"[handleRemainingUses] UPDATED call target: " << tgtFunc <<
'\n';
836 llvm::dbgs() <<
"[handleRemainingUses] UPDATED call target type: "
843 OpBuilder builder(c);
844 SmallVector<Value> splitArgs;
848 for (
auto [origName, newFieldRef] : newFields) {
849 splitArgs.push_back(builder.create<FieldReadOp>(
850 c.getLoc(), newFieldRef.getType(), originalBaseVal, newFieldRef.getNameAttr()
856 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
859 c.replaceAllUsesWith(builder.create<CallOp>(
866 llvm::dbgs() <<
"[handleRemainingUses] UPDATED function: "
867 << op->getParentOfType<FuncDefOp>() <<
'\n';
870 Operation *user =
use.getOwner();
872 if (!opWillBeDeleted(user)) {
873 return op->emitOpError()
875 "with use in '", user->getName().getStringRef(),
876 "' is not (currently) supported by this pass."
878 .attachNote(user->getLoc())
879 .append(
"used by this call");
884 if (!op->use_empty()) {
885 for (Operation *user : op->getUsers()) {
886 if (!opWillBeDeleted(user)) {
887 llvm::errs() <<
"Op has remaining use(s) that could not be removed: " << *op <<
'\n';
888 llvm_unreachable(
"Expected all uses to be removed");
895 inline static LogicalResult finalizeStruct(
896 SymbolTableCollection &tables, StructDefOp caller, PendingErasure &&toDelete,
897 DestToSrcToClonedSrcInDest &&destToSrcToClone
900 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
901 llvm::dbgs() << caller <<
'\n';
906 combineReadChain(readOp, tables, destToSrcToClone);
908 auto res = caller.
getComputeFuncOp().walk([&tables, &destToSrcToClone](FieldReadOp readOp) {
909 combineReadChain(readOp, tables, destToSrcToClone);
910 LogicalResult res = combineNewThenReadChain(readOp, tables, destToSrcToClone);
911 return failed(res) ? WalkResult::interrupt() : WalkResult::advance();
913 if (res.wasInterrupted()) {
918 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
919 llvm::dbgs() << caller <<
'\n';
920 llvm::dbgs() <<
"[finalizeStruct] ops marked for deletion:\n";
921 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
922 llvm::dbgs().indent(2) << op <<
'\n';
924 for (CreateStructOp op : toDelete.newStructOps) {
925 llvm::dbgs().indent(2) << op <<
'\n';
927 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
928 llvm::dbgs().indent(2) << op <<
'\n';
934 for (CreateStructOp op : toDelete.newStructOps) {
935 if (failed(handleRemainingUses(op, tables, destToSrcToClone, toDelete.fieldRefOps))) {
941 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
942 if (failed(handleRemainingUses(op, tables, destToSrcToClone))) {
947 for (CreateStructOp op : toDelete.newStructOps) {
951 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
952 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
953 assert(op.getParentOp() == caller);
954 callerSymTab.erase(op);
961 void runOnOperation()
override {
962 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
965 SymbolTableCollection tables;
966 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
972 for (
auto &[caller, callees] : plan.value()) {
975 PendingErasure toDelete;
977 DestToSrcToClonedSrcInDest aggregateReplacements;
979 for (StructDefOp toInline : callees) {
980 FailureOr<DestToSrcToClonedSrcInDest> res =
981 StructInliner(tables, toDelete, toInline, caller).doInline();
987 for (
auto &[k, v] : res.value()) {
988 assert(!aggregateReplacements.contains(k) &&
"duplicate not possible");
989 aggregateReplacements[k] = std::move(v);
993 LogicalResult finalizeResult =
994 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
995 if (failed(finalizeResult)) {
1006 return std::make_unique<InlineStructsPass>();
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
::llvm::StringRef getFieldName()
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the FieldRefOp.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
::mlir::OperandRangeRange getMapOperands()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::Operation::operand_range getArgOperands()
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FunctionType getFunctionType()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
::mlir::Region & getBody()
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
std::string toStringOne(const T &value)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::unique_ptr< mlir::Pass > createInlineStructsPass()
bool hasCycle(const GraphT &G)
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener)
A fast walk-based pattern rewrite driver.