34#include <mlir/IR/BuiltinOps.h>
35#include <mlir/Transforms/InliningUtils.h>
36#include <mlir/Transforms/WalkPatternRewriteDriver.h>
38#include <llvm/ADT/PostOrderIterator.h>
39#include <llvm/ADT/SmallVector.h>
40#include <llvm/ADT/StringMap.h>
41#include <llvm/ADT/TypeSwitch.h>
42#include <llvm/Support/Debug.h>
// Request both the declaration and the definition of the TableGen-generated
// InlineStructsPass base class (MLIR declarative-pass convention).
// NOTE(review): the generated `.inc` include that consumes these macros is not
// visible in this excerpt — confirm it follows these defines.
47#define GEN_PASS_DECL_INLINESTRUCTSPASS
48#define GEN_PASS_DEF_INLINESTRUCTSPASS
// Debug category for the LLVM_DEBUG(...) output emitted throughout this file;
// in an assertions-enabled build it is selected with
// `-debug-only=llzk-inline-structs`.
57#define DEBUG_TYPE "llzk-inline-structs"
// For a single inlined source struct: maps each source field's symbol name to
// the clone of that FieldDefOp which was inserted into the destination struct
// (populated in cloneFields() via `srcToClone[srcField.getSymNameAttr()] = newF`).
// A std::map is used so entries iterate in a stable, name-sorted order; the
// parameter-splitting code iterates this map in both its type conversion and
// its block-argument rewrite, and the two must agree on field order.
66using SrcStructFieldToCloneInDest = std::map<StringRef, DestCloneOfSrcStructField>;
// For each field of the destination struct whose type is the inlined source
// struct type: the name -> cloned-field mapping for that source struct's fields.
69using DestToSrcToClonedSrcInDest =
70 DenseMap<DestFieldWithSrcStructType, SrcStructFieldToCloneInDest>;
78 llvm_unreachable(
"expected \"compute\" or \"constrain\" function");
95 auto srcToClone = destToSrcToClone.find(getDef(tables, destFieldRefOp));
96 if (srcToClone == destToSrcToClone.end()) {
99 SrcStructFieldToCloneInDest oldToNewFields = srcToClone->second;
100 auto resNewField = oldToNewFields.find(readOp.
getFieldName());
101 if (resNewField == oldToNewFields.end()) {
106 OpBuilder builder(readOp);
108 readOp.getLoc(), readOp.getType(), destFieldRefOp.
getComponent(),
109 resNewField->second.getNameAttr()
111 readOp.replaceAllUsesWith(newRead.getOperation());
129bool combineReadChain(
131 const DestToSrcToClonedSrcInDest &destToSrcToClone
134 llvm::dyn_cast_if_present<FieldReadOp>(readOp.
getComponent().getDefiningOp());
135 if (!readThatDefinesBaseComponent) {
138 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
143FailureOr<FieldWriteOp>
144findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
146 for (Operation *user : writtenValue.getUsers()) {
147 if (
FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(user)) {
149 if (writeOp.getVal() == writtenValue) {
152 auto diag = emitError().append(
"result should not be written to more than one field.");
153 diag.attachNote(foundWrite.getLoc()).append(
"written here");
154 diag.attachNote(writeOp.getLoc()).append(
"written here");
157 foundWrite = writeOp;
164 return emitError().append(
"result should be written to a field.");
185LogicalResult combineNewThenReadChain(
187 const DestToSrcToClonedSrcInDest &destToSrcToClone
190 llvm::dyn_cast_if_present<CreateStructOp>(readOp.
getComponent().getDefiningOp());
191 if (!createThatDefinesBaseComponent) {
194 FailureOr<FieldWriteOp> foundWrite =
195 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
196 return createThatDefinesBaseComponent.emitOpError();
198 if (failed(foundWrite)) {
201 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
204inline FieldReadOp getFieldReadThatDefinesSelfValuePassedToConstrain(
CallOp callOp) {
206 return llvm::dyn_cast_if_present<FieldReadOp>(selfArgFromCall.getDefiningOp());
211struct PendingErasure {
212 SmallVector<FieldRefOpInterface> fieldRefOps;
213 SmallVector<CreateStructOp> newStructOps;
214 SmallVector<DestFieldWithSrcStructType> fieldDefs;
218 SymbolTableCollection &tables;
219 PendingErasure &toDelete;
221 StructDefOp srcStruct;
223 StructDefOp destStruct;
225 inline FieldDefOp getDef(FieldRefOpInterface fRef)
const { return ::getDef(tables, fRef); }
233 class FieldRefRewriter final :
public OpInterfaceRewritePattern<FieldRefOpInterface> {
241 const SrcStructFieldToCloneInDest &oldToNewFields;
245 FuncDefOp originalFunc, Value newRefBase,
246 const SrcStructFieldToCloneInDest &oldToNewFieldDef
248 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
249 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewFields(oldToNewFieldDef) {}
251 LogicalResult match(FieldRefOpInterface op)
const final {
256 return success(op.getComponent() == oldBaseVal && oldToNewFields.contains(op.getFieldName()));
259 void rewrite(FieldRefOpInterface op, PatternRewriter &rewriter)
const final {
260 rewriter.modifyOpInPlace(op, [
this, &op]() {
261 DestCloneOfSrcStructField newF = oldToNewFields.at(op.getFieldName());
262 op.setFieldName(newF.getSymName());
263 op.getComponentMutable().set(this->newBaseVal);
269 static FuncDefOp cloneWithFieldRefUpdate(std::unique_ptr<FieldRefRewriter> thisPat) {
271 FuncDefOp srcFuncClone = thisPat->funcRef.clone(mapper);
273 thisPat->funcRef = srcFuncClone;
274 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
276 MLIRContext *ctx = thisPat->getContext();
277 RewritePatternSet patterns(ctx, std::move(thisPat));
278 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
287 const StructInliner &data;
288 const DestToSrcToClonedSrcInDest &destToSrcToClone;
292 virtual FieldRefOpInterface getSelfRefField(CallOp callOp) = 0;
293 virtual void processCloneBeforeInlining(FuncDefOp func) {}
294 virtual ~ImplBase() =
default;
297 ImplBase(
const StructInliner &inliner,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
298 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
300 LogicalResult doInlining(FuncDefOp srcFunc, FuncDefOp destFunc) {
302 llvm::dbgs() <<
"[doInlining] SOURCE FUNCTION:\n";
304 llvm::dbgs() <<
"[doInlining] DESTINATION FUNCTION:\n";
308 InlinerInterface inliner(destFunc.getContext());
311 auto callHandler = [
this, &inliner, &srcFunc](CallOp callOp) {
314 assert(succeeded(callOpTarget));
315 if (callOpTarget->get() != srcFunc) {
316 return WalkResult::advance();
321 FieldRefOpInterface selfFieldRefOp = this->getSelfRefField(callOp);
322 if (!selfFieldRefOp) {
324 return WalkResult::interrupt();
330 FuncDefOp srcFuncClone = FieldRefRewriter::cloneWithFieldRefUpdate(
331 std::make_unique<FieldRefRewriter>(
333 this->destToSrcToClone.at(this->data.getDef(selfFieldRefOp))
336 this->processCloneBeforeInlining(srcFuncClone);
339 LogicalResult inlineCallRes =
340 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.
getBody(),
false);
341 if (failed(inlineCallRes)) {
343 return WalkResult::interrupt();
345 srcFuncClone.erase();
347 return WalkResult::skip();
350 auto fieldWriteHandler = [
this](FieldWriteOp writeOp) {
352 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
353 this->data.toDelete.fieldRefOps.push_back(writeOp);
355 return WalkResult::advance();
360 auto fieldReadHandler = [
this](FieldReadOp readOp) {
362 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
363 this->data.toDelete.fieldRefOps.push_back(readOp);
366 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
368 : WalkResult::advance();
371 WalkResult walkRes = destFunc.
getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
372 return TypeSwitch<Operation *, WalkResult>(op)
373 .Case<CallOp>(callHandler)
374 .Case<FieldWriteOp>(fieldWriteHandler)
375 .Case<FieldReadOp>(fieldReadHandler)
376 .Default([](Operation *) {
return WalkResult::advance(); });
379 return failure(walkRes.wasInterrupted());
383 class ConstrainImpl :
public ImplBase {
384 using ImplBase::ImplBase;
386 FieldRefOpInterface getSelfRefField(CallOp callOp)
override {
390 FieldRefOpInterface selfFieldRef = getFieldReadThatDefinesSelfValuePassedToConstrain(callOp);
392 selfFieldRef.
getComponent().getType() == this->data.destStruct.getType()) {
398 "\" to be passed a value read from a field in the current stuct."
405 class ComputeImpl :
public ImplBase {
406 using ImplBase::ImplBase;
408 FieldRefOpInterface getSelfRefField(CallOp callOp)
override {
414 FailureOr<FieldWriteOp> foundWrite =
416 return callOp.emitOpError().append(
"\"@", FUNC_NAME_COMPUTE,
"\" ");
418 return static_cast<FieldRefOpInterface
>(foundWrite.value_or(
nullptr));
421 void processCloneBeforeInlining(FuncDefOp func)
override {
425 func.
getBody().walk([
this](CreateStructOp newStructOp) {
426 if (newStructOp.getType() == this->data.srcStruct.getType()) {
427 this->data.toDelete.newStructOps.push_back(newStructOp);
436 DestToSrcToClonedSrcInDest cloneFields() {
437 DestToSrcToClonedSrcInDest destToSrcToClone;
439 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
440 StructType srcStructType = srcStruct.getType();
441 for (FieldDefOp destField : destStruct.getFieldDefs()) {
442 if (StructType destFieldType = llvm::dyn_cast<StructType>(destField.getType())) {
447 assert(unifications.empty());
449 toDelete.fieldDefs.push_back(destField);
452 SrcStructFieldToCloneInDest &srcToClone = destToSrcToClone[destField];
453 std::vector<FieldDefOp> srcFields = srcStruct.getFieldDefs();
454 if (srcFields.empty()) {
457 OpBuilder builder(destField);
458 std::string newNameBase =
460 for (FieldDefOp srcField : srcFields) {
461 DestCloneOfSrcStructField newF = llvm::cast<FieldDefOp>(builder.clone(*srcField));
462 newF.setName(builder.getStringAttr(newNameBase +
'+' + newF.getName()));
463 srcToClone[srcField.getSymNameAttr()] = newF;
465 destStructSymTable.insert(newF);
469 return destToSrcToClone;
473 inline LogicalResult inlineConstrainCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
474 return ConstrainImpl(*
this, destToSrcToClone)
475 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
479 inline LogicalResult inlineComputeCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
480 return ComputeImpl(*
this, destToSrcToClone)
481 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
486 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp
from, StructDefOp into
488 : tables(tbls), toDelete(opsToDelete), srcStruct(
from), destStruct(into) {}
490 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
492 llvm::dbgs() <<
"[StructInliner] merge " << srcStruct.getSymNameAttr() <<
" into "
493 << destStruct.getSymNameAttr() <<
'\n'
496 DestToSrcToClonedSrcInDest destToSrcToClone = cloneFields();
497 if (failed(inlineConstrainCall(destToSrcToClone)) ||
498 failed(inlineComputeCall(destToSrcToClone))) {
501 return destToSrcToClone;
508inline void splitFunctionParam(
509 FuncDefOp func,
unsigned paramIdx,
const SrcStructFieldToCloneInDest &nameToNewField
513 const SrcStructFieldToCloneInDest &newFields;
516 Impl(
unsigned paramIdx,
const SrcStructFieldToCloneInDest &nameToNewField)
517 : inputIdx(paramIdx), newFields(nameToNewField) {}
520 SmallVector<Type>
convertInputs(ArrayRef<Type> origTypes)
override {
521 SmallVector<Type> newTypes(origTypes);
522 auto it = newTypes.erase(newTypes.begin() + inputIdx);
523 for (
auto [_, newField] : newFields) {
524 newTypes.insert(it, newField.getType());
529 SmallVector<Type>
convertResults(ArrayRef<Type> origTypes)
override {
530 return SmallVector<Type>(origTypes);
535 SmallVector<Attribute> newAttrs(origAttrs.getValue());
536 newAttrs.insert(newAttrs.begin() + inputIdx, newFields.size() - 1, origAttrs[inputIdx]);
537 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
546 Value oldStructRef = entryBlock.getArgument(inputIdx);
550 llvm::StringMap<BlockArgument> fieldNameToNewArg;
551 Location loc = oldStructRef.getLoc();
552 unsigned idx = inputIdx;
553 for (
auto [fieldName, newField] : newFields) {
555 BlockArgument newArg = entryBlock.insertArgument(++idx, newField.getType(), loc);
556 fieldNameToNewArg[fieldName] = newArg;
561 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
562 if (FieldReadOp readOp = llvm::dyn_cast<FieldReadOp>(oldBlockArgUse.getOwner())) {
564 BlockArgument newArg = fieldNameToNewArg.at(readOp.
getFieldName());
565 rewriter.replaceAllUsesWith(readOp, newArg);
566 rewriter.eraseOp(readOp);
571 llvm::errs() <<
"Unexpected use of " << oldBlockArgUse.get() <<
" in "
572 << *oldBlockArgUse.getOwner() <<
'\n';
573 llvm_unreachable(
"Not yet implemented");
577 entryBlock.eraseArgument(inputIdx);
580 IRRewriter rewriter(func.getContext());
581 Impl(paramIdx, nameToNewField).convert(func, rewriter);
591 using InliningPlan = SmallVector<std::pair<StructDefOp, SmallVector<StructDefOp>>>;
593 static uint64_t complexity(FuncDefOp f) {
594 uint64_t complexity = 0;
595 f.
getBody().walk([&complexity](Operation *op) {
596 if (llvm::isa<felt::MulFeltOp>(op)) {
598 }
else if (
auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
600 }
else if (
auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
609 static FailureOr<FuncDefOp>
610 getIfStructConstrain(
const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
612 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
613 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
623 static inline StructDefOp getParentStruct(FuncDefOp func) {
626 assert(succeeded(currentNodeParentStruct));
627 return currentNodeParentStruct.value();
631 inline bool exceedsMaxComplexity(uint64_t check) {
632 return maxComplexity > 0 && check > maxComplexity;
637 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
649 WalkResult res = currentFunc.walk([](CallOp c) {
650 return getFieldReadThatDefinesSelfValuePassedToConstrain(c)
651 ? WalkResult::interrupt()
652 : WalkResult::advance();
658 return res.wasInterrupted();
665 inline FailureOr<InliningPlan>
666 makePlan(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
668 llvm::dbgs() <<
"Running InlineStructsPass with max complexity ";
669 if (maxComplexity == 0) {
670 llvm::dbgs() <<
"unlimited";
672 llvm::dbgs() << maxComplexity;
674 llvm::dbgs() <<
'\n';
677 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
689 for (
const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
690 LLVM_DEBUG(llvm::dbgs() <<
"\ncurrentNode = " << currentNode->toString());
691 if (!currentNode->isRealNode()) {
694 if (currentNode->isStructParam()) {
696 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
700 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
701 return reportLoc->emitError(
"Cannot inline structs with parameters.");
703 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
704 if (failed(currentFuncOpt)) {
707 FuncDefOp currentFunc = currentFuncOpt.value();
708 uint64_t currentComplexity = complexity(currentFunc);
710 if (exceedsMaxComplexity(currentComplexity)) {
711 complexityMemo[currentNode] = currentComplexity;
716 SmallVector<StructDefOp> successorsToMerge;
717 for (
const SymbolUseGraphNode *successor : currentNode->successorIter()) {
718 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"successor: " << successor->toString() <<
'\n');
720 auto memoResult = complexityMemo.find(successor);
721 if (memoResult == complexityMemo.end()) {
724 uint64_t sComplexity = memoResult->second;
726 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
727 "addition will overflow"
729 uint64_t potentialComplexity = currentComplexity + sComplexity;
730 if (!exceedsMaxComplexity(potentialComplexity)) {
731 currentComplexity = potentialComplexity;
732 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
733 assert(succeeded(successorFuncOpt));
734 FuncDefOp successorFunc = successorFuncOpt.value();
735 if (canInline(currentFunc, successorFunc)) {
736 successorsToMerge.push_back(getParentStruct(successorFunc));
740 complexityMemo[currentNode] = currentComplexity;
741 if (!successorsToMerge.empty()) {
742 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
746 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
747 llvm::dbgs() <<
"InlineStructsPass plan:\n";
748 for (
auto &[caller, callees] : retVal) {
749 llvm::dbgs().indent(2) <<
"inlining the following into \"" << caller.getSymName() <<
"\"\n";
750 for (StructDefOp c : callees) {
751 llvm::dbgs().indent(4) <<
"\"" << c.getSymName() <<
"\"\n";
754 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
764 static LogicalResult handleRemainingUses(
765 Operation *op, SymbolTableCollection &tables,
766 const DestToSrcToClonedSrcInDest &destToSrcToClone,
767 ArrayRef<FieldRefOpInterface> otherRefsToBeDeleted = {}
769 if (op->use_empty()) {
774 auto opWillBeDeleted = [&otherRefsToBeDeleted](Operation *otherOp) ->
bool {
775 return std::find(otherRefsToBeDeleted.begin(), otherRefsToBeDeleted.end(), otherOp) !=
776 otherRefsToBeDeleted.end();
780 llvm::dbgs() <<
"[handleRemainingUses] op: " << *op <<
'\n';
781 llvm::dbgs() <<
"[handleRemainingUses] in function: " << op->getParentOfType<FuncDefOp>()
784 for (OpOperand &
use : llvm::make_early_inc_range(op->getUses())) {
785 if (CallOp c = llvm::dyn_cast<CallOp>(
use.getOwner())) {
786 LLVM_DEBUG(llvm::dbgs() <<
"[handleRemainingUses] use in call: " << c <<
'\n');
787 unsigned argIdx =
use.getOperandNumber() - c.
getArgOperands().getBeginOperandIndex();
788 LLVM_DEBUG(llvm::dbgs() <<
"[handleRemainingUses] at index: " << argIdx <<
'\n');
791 if (failed(tgtFuncRes)) {
793 ->emitOpError(
"as argument to an unknown function is not supported by this pass.")
794 .attachNote(c.getLoc())
795 .append(
"used by this call");
797 FuncDefOp tgtFunc = tgtFuncRes->get();
798 LLVM_DEBUG(llvm::dbgs() <<
"[handleRemainingUses] call target: " << tgtFunc <<
'\n');
799 if (tgtFunc.isExternal()) {
803 ->emitOpError(
"as argument to a no-body free function is not supported by this pass.")
804 .attachNote(c.getLoc())
805 .append(
"used by this call");
808 FieldRefOpInterface paramFromField = TypeSwitch<Operation *, FieldRefOpInterface>(op)
809 .Case<FieldReadOp>([](
auto p) {
return p; })
810 .Case<CreateStructOp>([](
auto p) {
811 return findOpThatStoresSubcmp(p, [&p]() {
return p.emitOpError(); }).value_or(
nullptr);
812 }).Default([](Operation *p) {
813 llvm::errs() <<
"Encountered unexpected op: "
814 << (p ? p->getName().getStringRef() :
"<<null>>") <<
'\n';
815 llvm_unreachable(
"Unexpected op kind");
819 llvm::dbgs() <<
"[handleRemainingUses] field ref op for param: "
823 if (!paramFromField) {
826 const SrcStructFieldToCloneInDest &newFields =
827 destToSrcToClone.at(getDef(tables, paramFromField));
829 llvm::dbgs() <<
"[handleRemainingUses] fields to split: "
834 splitFunctionParam(tgtFunc, argIdx, newFields);
836 llvm::dbgs() <<
"[handleRemainingUses] UPDATED call target: " << tgtFunc <<
'\n';
837 llvm::dbgs() <<
"[handleRemainingUses] UPDATED call target type: "
844 OpBuilder builder(c);
845 SmallVector<Value> splitArgs;
849 for (
auto [origName, newFieldRef] : newFields) {
850 splitArgs.push_back(builder.create<FieldReadOp>(
851 c.getLoc(), newFieldRef.getType(), originalBaseVal, newFieldRef.getNameAttr()
857 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
860 c.replaceAllUsesWith(builder.create<CallOp>(
867 llvm::dbgs() <<
"[handleRemainingUses] UPDATED function: "
868 << op->getParentOfType<FuncDefOp>() <<
'\n';
871 Operation *user =
use.getOwner();
873 if (!opWillBeDeleted(user)) {
874 return op->emitOpError()
876 "with use in '", user->getName().getStringRef(),
877 "' is not (currently) supported by this pass."
879 .attachNote(user->getLoc())
880 .append(
"used by this call");
885 if (!op->use_empty()) {
886 for (Operation *user : op->getUsers()) {
887 if (!opWillBeDeleted(user)) {
888 llvm::errs() <<
"Op has remaining use(s) that could not be removed: " << *op <<
'\n';
889 llvm_unreachable(
"Expected all uses to be removed");
896 inline static LogicalResult finalizeStruct(
897 SymbolTableCollection &tables, StructDefOp caller, PendingErasure &&toDelete,
898 DestToSrcToClonedSrcInDest &&destToSrcToClone
901 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
902 llvm::dbgs() << caller <<
'\n';
907 combineReadChain(readOp, tables, destToSrcToClone);
909 auto res = caller.
getComputeFuncOp().walk([&tables, &destToSrcToClone](FieldReadOp readOp) {
910 combineReadChain(readOp, tables, destToSrcToClone);
911 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
912 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
914 if (res.wasInterrupted()) {
919 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
920 llvm::dbgs() << caller <<
'\n';
921 llvm::dbgs() <<
"[finalizeStruct] ops marked for deletion:\n";
922 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
923 llvm::dbgs().indent(2) << op <<
'\n';
925 for (CreateStructOp op : toDelete.newStructOps) {
926 llvm::dbgs().indent(2) << op <<
'\n';
928 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
929 llvm::dbgs().indent(2) << op <<
'\n';
935 for (CreateStructOp op : toDelete.newStructOps) {
936 if (failed(handleRemainingUses(op, tables, destToSrcToClone, toDelete.fieldRefOps))) {
942 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
943 if (failed(handleRemainingUses(op, tables, destToSrcToClone))) {
948 for (CreateStructOp op : toDelete.newStructOps) {
952 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
953 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
954 assert(op.getParentOp() == caller);
955 callerSymTab.erase(op);
962 void runOnOperation()
override {
963 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
966 SymbolTableCollection tables;
967 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
973 for (
auto &[caller, callees] : plan.value()) {
976 PendingErasure toDelete;
978 DestToSrcToClonedSrcInDest aggregateReplacements;
980 for (StructDefOp toInline : callees) {
981 FailureOr<DestToSrcToClonedSrcInDest> res =
982 StructInliner(tables, toDelete, toInline, caller).doInline();
988 for (
auto &[k, v] : res.value()) {
989 assert(!aggregateReplacements.contains(k) &&
"duplicate not possible");
990 aggregateReplacements[k] = std::move(v);
994 LogicalResult finalizeResult =
995 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
996 if (failed(finalizeResult)) {
1007 return std::make_unique<InlineStructsPass>();
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
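Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.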
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
::llvm::StringRef getFieldName()
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the FieldRefOp.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
::mlir::Operation::operand_range getArgOperands()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FunctionType getFunctionType()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
::mlir::Region & getBody()
std::string toStringOne(const T &value)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::unique_ptr< mlir::Pass > createInlineStructsPass()
bool hasCycle(const GraphT &G)