LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
LLZKInlineStructsPass.cpp
Go to the documentation of this file.
1//===-- LLZKInlineStructsPass.cpp -------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
19//===----------------------------------------------------------------------===//
31#include "llzk/Util/Debug.h"
34
35#include <mlir/IR/BuiltinOps.h>
36#include <mlir/Transforms/InliningUtils.h>
37
38#include <llvm/ADT/PostOrderIterator.h>
39#include <llvm/ADT/SmallVector.h>
40#include <llvm/ADT/StringMap.h>
41#include <llvm/ADT/TypeSwitch.h>
42#include <llvm/Support/Debug.h>
43
44// Include the generated base pass class definitions.
45namespace llzk {
46// the *DECL* macro is required when a pass has options to declare the option struct
47#define GEN_PASS_DECL_INLINESTRUCTSPASS
48#define GEN_PASS_DEF_INLINESTRUCTSPASS
50} // namespace llzk
51
52using namespace mlir;
53using namespace llzk;
54using namespace llzk::component;
55using namespace llzk::function;
56
57#define DEBUG_TYPE "llzk-inline-structs"
58
59namespace {
61using DestFieldWithSrcStructType = FieldDefOp;
62using DestCloneOfSrcStructField = FieldDefOp;
66using SrcStructFieldToCloneInDest = std::map<StringRef, DestCloneOfSrcStructField>;
69using DestToSrcToClonedSrcInDest =
70 DenseMap<DestFieldWithSrcStructType, SrcStructFieldToCloneInDest>;
71
72Value getSelfValue(FuncDefOp f) {
73 if (f.nameIsCompute()) {
74 return f.getSelfValueFromCompute();
75 } else if (f.nameIsConstrain()) {
77 } else {
78 llvm_unreachable("expected \"compute\" or \"constrain\" function");
79 }
80}
81
82inline FieldDefOp getDef(SymbolTableCollection &tables, FieldRefOpInterface fRef) {
83 auto r = fRef.getFieldDefOp(tables);
84 assert(succeeded(r));
85 return r->get();
87
89
91bool combineHelper(
92 FieldReadOp readOp, SymbolTableCollection &tables,
93 const DestToSrcToClonedSrcInDest &destToSrcToClone, FieldRefOpInterface destFieldRefOp
94) {
95 auto srcToClone = destToSrcToClone.find(getDef(tables, destFieldRefOp));
96 if (srcToClone == destToSrcToClone.end()) {
97 return false;
98 }
99 SrcStructFieldToCloneInDest oldToNewFields = srcToClone->second;
100 auto resNewField = oldToNewFields.find(readOp.getFieldName());
101 if (resNewField == oldToNewFields.end()) {
102 return false;
103 }
104
105 // Replace this FieldReadOp with a new one that targets the cloned field.
106 OpBuilder builder(readOp);
107 FieldReadOp newRead = builder.create<FieldReadOp>(
108 readOp.getLoc(), readOp.getType(), destFieldRefOp.getComponent(),
109 resNewField->second.getNameAttr()
110 );
111 readOp.replaceAllUsesWith(newRead.getOperation());
112 readOp.erase(); // delete the original FieldReadOp
113 return true;
114}
115
129bool combineReadChain(
130 FieldReadOp readOp, SymbolTableCollection &tables,
131 const DestToSrcToClonedSrcInDest &destToSrcToClone
132) {
133 FieldReadOp readThatDefinesBaseComponent =
134 llvm::dyn_cast_if_present<FieldReadOp>(readOp.getComponent().getDefiningOp());
135 if (!readThatDefinesBaseComponent) {
136 return false;
137 }
138 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
139}
140
143FailureOr<FieldWriteOp>
144findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
145 FieldWriteOp foundWrite = nullptr;
146 for (Operation *user : writtenValue.getUsers()) {
147 if (FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(user)) {
148 // Find the write op that stores the created value
149 if (writeOp.getVal() == writtenValue) {
150 if (foundWrite) {
151 // Note: There is no reason for a subcomponent to be stored to more than one field.
152 auto diag = emitError().append("result should not be written to more than one field.");
153 diag.attachNote(foundWrite.getLoc()).append("written here");
154 diag.attachNote(writeOp.getLoc()).append("written here");
155 return diag;
156 } else {
157 foundWrite = writeOp;
158 }
159 }
160 }
161 }
162 if (!foundWrite) {
163 // Note: There is no reason to construct a subcomponent and not store it to a field.
164 return emitError().append("result should be written to a field.");
165 }
166 return foundWrite;
167}
168
185LogicalResult combineNewThenReadChain(
186 FieldReadOp readOp, SymbolTableCollection &tables,
187 const DestToSrcToClonedSrcInDest &destToSrcToClone
188) {
189 CreateStructOp createThatDefinesBaseComponent =
190 llvm::dyn_cast_if_present<CreateStructOp>(readOp.getComponent().getDefiningOp());
191 if (!createThatDefinesBaseComponent) {
192 return success(); // No error. The pattern simply doesn't match.
193 }
194 FailureOr<FieldWriteOp> foundWrite =
195 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
196 return createThatDefinesBaseComponent.emitOpError();
197 });
198 if (failed(foundWrite)) {
199 return failure(); // error already printed within findOpThatStoresSubcmp()
200 }
201 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
202}
203
204inline FieldReadOp getFieldReadThatDefinesSelfValuePassedToConstrain(CallOp callOp) {
205 Value selfArgFromCall = callOp.getSelfValueFromConstrain();
206 return llvm::dyn_cast_if_present<FieldReadOp>(selfArgFromCall.getDefiningOp());
207}
208
211struct PendingErasure {
212 SmallVector<FieldRefOpInterface> fieldRefOps;
213 SmallVector<CreateStructOp> newStructOps;
214 SmallVector<DestFieldWithSrcStructType> fieldDefs;
215};
216
217class StructInliner {
218 SymbolTableCollection &tables;
219 PendingErasure &toDelete;
221 StructDefOp srcStruct;
223 StructDefOp destStruct;
224
225 inline FieldDefOp getDef(FieldRefOpInterface fRef) const { return ::getDef(tables, fRef); }
226
227 // Update field read/write ops that target the "self" value of the FuncDefOp plus some key in
228 // `oldToNewFieldDef` to instead target the new base Value provided to the constructor plus the
229 // mapped Value from `oldToNewFieldDef`.
230 // Example:
231 // old: %1 = struct.readf %0[@f1] : <@Component1A>, !felt.type
232 // new: %1 = struct.readf %self[@"f2:!s<@Component1A>+f1"] : <@Component1B>, !felt.type
233 class FieldRefRewriter final : public OpInterfaceRewritePattern<FieldRefOpInterface> {
236 FuncDefOp funcRef;
238 Value oldBaseVal;
240 Value newBaseVal;
241 const SrcStructFieldToCloneInDest &oldToNewFields;
242
243 public:
244 FieldRefRewriter(
245 FuncDefOp originalFunc, Value newRefBase,
246 const SrcStructFieldToCloneInDest &oldToNewFieldDef
247 )
248 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
249 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewFields(oldToNewFieldDef) {}
250
251 LogicalResult match(FieldRefOpInterface op) const final {
252 assert(oldBaseVal); // ensure it's used via `cloneWithFieldRefUpdate()` only
253 // Check if the FieldRef accesses a field of "self" within the `oldToNewFields` map.
254 // Per `cloneWithFieldRefUpdate()`, `oldBaseVal` is the "self" value of `funcRef` so
255 // check for a match there and then check that the referenced field name is in the map.
256 return success(op.getComponent() == oldBaseVal && oldToNewFields.contains(op.getFieldName()));
257 }
258
259 void rewrite(FieldRefOpInterface op, PatternRewriter &rewriter) const final {
260 rewriter.modifyOpInPlace(op, [this, &op]() {
261 DestCloneOfSrcStructField newF = oldToNewFields.at(op.getFieldName());
262 op.setFieldName(newF.getSymName());
263 op.getComponentMutable().set(this->newBaseVal);
264 });
265 }
266
269 static FuncDefOp cloneWithFieldRefUpdate(std::unique_ptr<FieldRefRewriter> thisPat) {
270 IRMapping mapper;
271 FuncDefOp srcFuncClone = thisPat->funcRef.clone(mapper);
272 // Update some data in the `FieldRefRewriter` instance before moving it.
273 thisPat->funcRef = srcFuncClone;
274 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
275 // Run the rewriter to replace read/write ops
276 MLIRContext *ctx = thisPat->getContext();
277 RewritePatternSet patterns(ctx, std::move(thisPat));
278 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
279
280 return srcFuncClone;
281 }
282 };
283
285 class ImplBase {
286 protected:
287 const StructInliner &data;
288 const DestToSrcToClonedSrcInDest &destToSrcToClone;
289
292 virtual FieldRefOpInterface getSelfRefField(CallOp callOp) = 0;
293 virtual void processCloneBeforeInlining(FuncDefOp func) {}
294 virtual ~ImplBase() = default;
295
296 public:
297 ImplBase(const StructInliner &inliner, const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
298 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
299
300 LogicalResult doInlining(FuncDefOp srcFunc, FuncDefOp destFunc) {
301 LLVM_DEBUG({
302 llvm::dbgs() << "[doInlining] SOURCE FUNCTION:\n";
303 srcFunc.dump();
304 llvm::dbgs() << "[doInlining] DESTINATION FUNCTION:\n";
305 destFunc.dump();
306 });
307
308 InlinerInterface inliner(destFunc.getContext());
309
311 auto callHandler = [this, &inliner, &srcFunc](CallOp callOp) {
312 // Ensure the CallOp targets `srcFunc`
313 auto callOpTarget = callOp.getCalleeTarget(this->data.tables);
314 assert(succeeded(callOpTarget));
315 if (callOpTarget->get() != srcFunc) {
316 return WalkResult::advance();
317 }
318
319 // Get the "self" struct parameter from the CallOp and determine which field that struct
320 // was stored in within the caller (i.e. `destFunc`).
321 FieldRefOpInterface selfFieldRefOp = this->getSelfRefField(callOp);
322 if (!selfFieldRefOp) {
323 // Note: error message was already printed within `getSelfRefField()`
324 return WalkResult::interrupt(); // use interrupt to signal failure
325 }
326
327 // Create a clone of the source function (must do the whole function not just the body
328 // region because `inlineCall()` expects the Region to have a parent op) and update field
329 // references to the old struct fields to instead use the new struct fields.
330 FuncDefOp srcFuncClone =
331 FieldRefRewriter::cloneWithFieldRefUpdate(std::make_unique<FieldRefRewriter>(
332 srcFunc, selfFieldRefOp.getComponent(),
333 this->destToSrcToClone.at(this->data.getDef(selfFieldRefOp))
334 ));
335 this->processCloneBeforeInlining(srcFuncClone);
336
337 // Inline the cloned function in place of `callOp`
338 LogicalResult inlineCallRes =
339 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.getBody(), false);
340 if (failed(inlineCallRes)) {
341 callOp.emitError().append("Failed to inline ", srcFunc.getFullyQualifiedName()).report();
342 return WalkResult::interrupt(); // use interrupt to signal failure
343 }
344 srcFuncClone.erase(); // delete what's left after transferring the body elsewhere
345 callOp.erase(); // delete the original CallOp
346 return WalkResult::skip(); // Must skip because the CallOp was erased.
347 };
348
349 auto fieldWriteHandler = [this](FieldWriteOp writeOp) {
350 // Check if the field ref op should be deleted in the end
351 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
352 this->data.toDelete.fieldRefOps.push_back(writeOp);
353 }
354 return WalkResult::advance();
355 };
356
359 auto fieldReadHandler = [this](FieldReadOp readOp) {
360 // Check if the field ref op should be deleted in the end
361 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
362 this->data.toDelete.fieldRefOps.push_back(readOp);
363 }
364 // If the FieldReadOp was replaced/erased, must skip.
365 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
366 ? WalkResult::skip()
367 : WalkResult::advance();
368 };
369
370 WalkResult walkRes = destFunc.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
371 return TypeSwitch<Operation *, WalkResult>(op)
372 .Case<CallOp>(callHandler)
373 .Case<FieldWriteOp>(fieldWriteHandler)
374 .Case<FieldReadOp>(fieldReadHandler)
375 .Default([](Operation *) { return WalkResult::advance(); });
376 });
377
378 return failure(walkRes.wasInterrupted());
379 }
380 };
381
382 class ConstrainImpl : public ImplBase {
383 using ImplBase::ImplBase;
384
385 FieldRefOpInterface getSelfRefField(CallOp callOp) override {
386 // The typical pattern is to read a struct instance from a field and then call "constrain()"
387 // on it. Get the Value passed as the "self" struct to the CallOp and determine which field it
388 // was read from in the current struct (i.e., `destStruct`).
389 FieldRefOpInterface selfFieldRef = getFieldReadThatDefinesSelfValuePassedToConstrain(callOp);
390 if (selfFieldRef &&
391 selfFieldRef.getComponent().getType() == this->data.destStruct.getType()) {
392 return selfFieldRef;
393 }
394 callOp.emitError()
395 .append(
396 "expected \"self\" parameter to \"@", FUNC_NAME_CONSTRAIN,
397 "\" to be passed a value read from a field in the current stuct."
398 )
399 .report();
400 return nullptr;
401 }
402 };
403
404 class ComputeImpl : public ImplBase {
405 using ImplBase::ImplBase;
406
407 FieldRefOpInterface getSelfRefField(CallOp callOp) override {
408 // The typical pattern is to write the return value of "compute()" to a field in
409 // the current struct (i.e., `destStruct`).
410 // It doesn't really make sense (although there is no semantic restriction against it) to just
411 // pass the "compute()" result into another function and never write it to a field since that
412 // leaves no way for the "constrain()" function to call "constrain()" on that result struct.
413 FailureOr<FieldWriteOp> foundWrite =
414 findOpThatStoresSubcmp(callOp.getSelfValueFromCompute(), [&callOp]() {
415 return callOp.emitOpError().append("\"@", FUNC_NAME_COMPUTE, "\" ");
416 });
417 return static_cast<FieldRefOpInterface>(foundWrite.value_or(nullptr));
418 }
419
420 void processCloneBeforeInlining(FuncDefOp func) override {
421 // Within the compute function, find `CreateStructOp` with `srcStruct` type and mark them
422 // for later deletion. The deletion must occur later because these values may still have
423 // uses until ALL callees of a function have been inlined.
424 func.getBody().walk([this](CreateStructOp newStructOp) {
425 if (newStructOp.getType() == this->data.srcStruct.getType()) {
426 this->data.toDelete.newStructOps.push_back(newStructOp);
427 }
428 });
429 }
430 };
431
432 // Find any field(s) in `destStruct` whose type matches `srcStruct` (allowing any parameters, if
433 // applicable). For each such field, clone all fields from `srcStruct` into `destStruct` and cache
434 // the mapping of `destStruct` to `srcStruct` to cloned fields in the return value.
435 DestToSrcToClonedSrcInDest cloneFields() {
436 DestToSrcToClonedSrcInDest destToSrcToClone;
437
438 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
439 StructType srcStructType = srcStruct.getType();
440 for (FieldDefOp destField : destStruct.getFieldDefs()) {
441 if (StructType destFieldType = llvm::dyn_cast<StructType>(destField.getType())) {
442 UnificationMap unifications;
443 if (!structTypesUnify(srcStructType, destFieldType, {}, &unifications)) {
444 continue;
445 }
446 assert(unifications.empty()); // `makePlan()` reports failure earlier
447 // Mark the original `destField` for deletion
448 toDelete.fieldDefs.push_back(destField);
449 // Clone each field from 'srcStruct' into 'destStruct'. Add an entry to `destToSrcToClone`
450 // even if there are no fields in `srcStruct` so its presence can be used as a marker.
451 SrcStructFieldToCloneInDest &srcToClone = destToSrcToClone.getOrInsertDefault(destField);
452 std::vector<FieldDefOp> srcFields = srcStruct.getFieldDefs();
453 if (srcFields.empty()) {
454 continue;
455 }
456 OpBuilder builder(destField);
457 std::string newNameBase =
458 destField.getName().str() + ':' + BuildShortTypeString::from(destFieldType);
459 for (FieldDefOp srcField : srcFields) {
460 DestCloneOfSrcStructField newF = llvm::cast<FieldDefOp>(builder.clone(*srcField));
461 newF.setName(builder.getStringAttr(newNameBase + '+' + newF.getName()));
462 srcToClone[srcField.getSymNameAttr()] = newF;
463 // Also update the cached SymbolTable
464 destStructSymTable.insert(newF);
465 }
466 }
467 }
468 return destToSrcToClone;
469 }
470
472 inline LogicalResult inlineConstrainCall(const DestToSrcToClonedSrcInDest &destToSrcToClone) {
473 return ConstrainImpl(*this, destToSrcToClone)
474 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
475 }
476
478 inline LogicalResult inlineComputeCall(const DestToSrcToClonedSrcInDest &destToSrcToClone) {
479 return ComputeImpl(*this, destToSrcToClone)
480 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
481 }
482
483public:
484 StructInliner(
485 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp from, StructDefOp into
486 )
487 : tables(tbls), toDelete(opsToDelete), srcStruct(from), destStruct(into) {}
488
489 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
490 LLVM_DEBUG(
491 llvm::dbgs() << "[StructInliner] merge " << srcStruct.getSymNameAttr() << " into "
492 << destStruct.getSymNameAttr() << '\n'
493 );
494
495 DestToSrcToClonedSrcInDest destToSrcToClone = cloneFields();
496 if (failed(inlineConstrainCall(destToSrcToClone)) ||
497 failed(inlineComputeCall(destToSrcToClone))) {
498 return failure(); // error already printed within doInlining()
499 }
500 return destToSrcToClone;
501 }
502};
503
507inline void splitFunctionParam(
508 FuncDefOp func, unsigned paramIdx, const SrcStructFieldToCloneInDest &nameToNewField
509) {
510 class Impl : public FunctionTypeConverter {
511 unsigned inputIdx;
512 const SrcStructFieldToCloneInDest &newFields;
513
514 public:
515 Impl(unsigned paramIdx, const SrcStructFieldToCloneInDest &nameToNewField)
516 : inputIdx(paramIdx), newFields(nameToNewField) {}
517
518 protected:
519 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
520 SmallVector<Type> newTypes(origTypes);
521 auto it = newTypes.erase(newTypes.begin() + inputIdx);
522 for (auto [_, newField] : newFields) {
523 newTypes.insert(it, newField.getType());
524 ++it;
525 }
526 return newTypes;
527 }
528 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
529 return SmallVector<Type>(origTypes);
530 }
531 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type>) override {
532 if (origAttrs) {
533 // Replicate the value at `origAttrs[inputIdx]` to have `newFields.size()`
534 SmallVector<Attribute> newAttrs(origAttrs.getValue());
535 newAttrs.insert(newAttrs.begin() + inputIdx, newFields.size() - 1, origAttrs[inputIdx]);
536 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
537 }
538 return nullptr;
539 }
540 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type>) override {
541 return origAttrs;
542 }
543
544 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
545 Value oldStructRef = entryBlock.getArgument(inputIdx);
546
547 // Insert new Block arguments, one per field, following the original one. Keep a map
548 // of field name to the associated block argument for replacing FieldReadOp.
549 llvm::StringMap<BlockArgument> fieldNameToNewArg;
550 Location loc = oldStructRef.getLoc();
551 unsigned idx = inputIdx;
552 for (auto [fieldName, newField] : newFields) {
553 // note: pre-increment so the original to be erased is still at `inputIdx`
554 BlockArgument newArg = entryBlock.insertArgument(++idx, newField.getType(), loc);
555 fieldNameToNewArg[fieldName] = newArg;
556 }
557
558 // Find all field reads from the original Block argument and replace uses of those
559 // reads with the appropriate new Block argument.
560 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
561 if (FieldReadOp readOp = llvm::dyn_cast<FieldReadOp>(oldBlockArgUse.getOwner())) {
562 if (readOp.getComponent() == oldStructRef) {
563 BlockArgument newArg = fieldNameToNewArg.at(readOp.getFieldName());
564 rewriter.replaceAllUsesWith(readOp, newArg);
565 rewriter.eraseOp(readOp);
566 continue;
567 }
568 }
569 // Currently, there's no other way in which a StructType parameter can be used.
570 llvm::errs() << "Unexpected use of " << oldBlockArgUse.get() << " in "
571 << *oldBlockArgUse.getOwner() << '\n';
572 llvm_unreachable("Not yet implemented");
573 }
574
575 // Delete the original Block argument
576 entryBlock.eraseArgument(inputIdx);
577 }
578 };
579 IRRewriter rewriter(func.getContext());
580 Impl(paramIdx, nameToNewField).convert(func, rewriter);
581}
582
583class InlineStructsPass : public llzk::impl::InlineStructsPassBase<InlineStructsPass> {
590 using InliningPlan = SmallVector<std::pair<StructDefOp, SmallVector<StructDefOp>>>;
591
592 static uint64_t complexity(FuncDefOp f) {
593 uint64_t complexity = 0;
594 f.getBody().walk([&complexity](Operation *op) {
595 if (llvm::isa<felt::MulFeltOp>(op)) {
596 ++complexity;
597 } else if (auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
598 complexity += computeEmitEqCardinality(ee.getLhs().getType());
599 } else if (auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
600 // TODO: increment based on dimension sizes in the operands
601 // Pending update to implementation/semantics of EmitContainmentOp.
602 ++complexity;
603 }
604 });
605 return complexity;
606 }
607
608 static FailureOr<FuncDefOp>
609 getIfStructConstrain(const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
610 auto lookupRes = node->lookupSymbol(tables, false);
611 assert(succeeded(lookupRes) && "graph contains node with invalid path");
612 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
613 if (f.isStructConstrain()) {
614 return f;
615 }
616 }
617 return failure();
618 }
619
622 static inline StructDefOp getParentStruct(FuncDefOp func) {
623 assert(func.isStructConstrain()); // pre-condition
624 FailureOr<StructDefOp> currentNodeParentStruct = getParentOfType<StructDefOp>(func);
625 assert(succeeded(currentNodeParentStruct)); // follows from ODS definition
626 return currentNodeParentStruct.value();
627 }
628
630 inline bool exceedsMaxComplexity(uint64_t check) {
631 return maxComplexity > 0 && check > maxComplexity;
632 }
633
636 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
637 // Find CallOp for `successorFunc` within `currentFunc` and check the condition used by
638 // `ConstrainImpl::getSelfRefField()`.
639 //
640 // Implementation Note: There is a possibility that the "self" value is not from a field read.
641 // It could be a parameter to the current/destination function or a global read. Inlining a
642 // struct stored to a global would probably require splitting up the global into multiple, one
643 // for each field in the successor/source struct. That may not be a good idea. The parameter
644 // case could be handled but it will not have a mapping in `destToSrcToClone` in
645 // `getSelfRefField()` and new fields will still need to be added. They can be prefixed with
646 // parameter index since there is no current field name to use as the unique prefix. Handling
647 // that would require refactoring the inlining process a bit.
648 WalkResult res = currentFunc.walk([](CallOp c) {
649 return getFieldReadThatDefinesSelfValuePassedToConstrain(c)
650 ? WalkResult::interrupt() // use interrupt to indicate success
651 : WalkResult::advance();
652 });
653 LLVM_DEBUG({
654 llvm::dbgs() << "[canInline] " << successorFunc.getFullyQualifiedName() << " into "
655 << currentFunc.getFullyQualifiedName() << "? " << res.wasInterrupted() << '\n';
656 });
657 return res.wasInterrupted();
658 }
659
664 inline FailureOr<InliningPlan>
665 makePlan(const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
666 LLVM_DEBUG({
667 llvm::dbgs() << "Running InlineStructsPass with max complexity ";
668 if (maxComplexity == 0) {
669 llvm::dbgs() << "unlimited";
670 } else {
671 llvm::dbgs() << maxComplexity;
672 }
673 llvm::dbgs() << '\n';
674 });
675 InliningPlan retVal;
676 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
677
678 // NOTE: The assumption that the use graph has no cycles allows `complexityMemo` to only
679 // store the result for relevant nodes and assume nodes without a mapped value are `0`. This
680 // must be true of the "compute"/"constrain" function uses and field defs because circuits
681 // must be acyclic. This is likely true to for the symbol use graph is general but if a
682 // counterexample is ever found, the algorithm below must be re-evaluated.
683 assert(!hasCycle(&useGraph));
684
685 // Traverse "constrain" function nodes to compute their complexity and an inlining plan. Use
686 // post-order traversal so the complexity of all successor nodes is computed before computing
687 // the current node's complexity.
688 for (const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
689 LLVM_DEBUG(llvm::dbgs() << "\ncurrentNode = " << currentNode->toString());
690 if (!currentNode->isRealNode()) {
691 continue;
692 }
693 if (currentNode->isStructParam()) {
694 // Try to get the location of the StructDefOp to report an error.
695 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
696 SymbolRefAttr prefix = getPrefixAsSymbolRefAttr(currentNode->getSymbolPath());
697 auto res = lookupSymbolIn<StructDefOp>(tables, prefix, lookupFrom, lookupFrom, false);
698 // If that lookup didn't work for some reason, report at the path root location.
699 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
700 return reportLoc->emitError("Cannot inline structs with parameters.");
701 }
702 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
703 if (failed(currentFuncOpt)) {
704 continue;
705 }
706 FuncDefOp currentFunc = currentFuncOpt.value();
707 uint64_t currentComplexity = complexity(currentFunc);
708 // If the current complexity is already too high, store it and continue.
709 if (exceedsMaxComplexity(currentComplexity)) {
710 complexityMemo[currentNode] = currentComplexity;
711 continue;
712 }
713 // Otherwise, make a plan that adds successor "constrain" functions unless the
714 // complexity becomes too high by adding that successor.
715 SmallVector<StructDefOp> successorsToMerge;
716 for (const SymbolUseGraphNode *successor : currentNode->successorIter()) {
717 LLVM_DEBUG(llvm::dbgs().indent(2) << "successor: " << successor->toString() << '\n');
718 // Note: all "constrain" function nodes will have a value, and all other nodes will not.
719 auto memoResult = complexityMemo.find(successor);
720 if (memoResult == complexityMemo.end()) {
721 continue; // inner loop
722 }
723 uint64_t sComplexity = memoResult->second;
724 assert(
725 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
726 "addition will overflow"
727 );
728 uint64_t potentialComplexity = currentComplexity + sComplexity;
729 if (!exceedsMaxComplexity(potentialComplexity)) {
730 currentComplexity = potentialComplexity;
731 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
732 assert(succeeded(successorFuncOpt)); // follows from the Note above
733 FuncDefOp successorFunc = successorFuncOpt.value();
734 if (canInline(currentFunc, successorFunc)) {
735 successorsToMerge.push_back(getParentStruct(successorFunc));
736 }
737 }
738 }
739 complexityMemo[currentNode] = currentComplexity;
740 if (!successorsToMerge.empty()) {
741 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
742 }
743 }
744 LLVM_DEBUG({
745 llvm::dbgs() << "-----------------------------------------------------------------\n";
746 llvm::dbgs() << "InlineStructsPass plan:\n";
747 for (auto &[caller, callees] : retVal) {
748 llvm::dbgs().indent(2) << "inlining the following into \"" << caller.getSymName() << "\"\n";
749 for (StructDefOp c : callees) {
750 llvm::dbgs().indent(4) << "\"" << c.getSymName() << "\"\n";
751 }
752 }
753 llvm::dbgs() << "-----------------------------------------------------------------\n";
754 });
755 return retVal;
756 }
757
763 static LogicalResult handleRemainingUses(
764 Operation *op, SymbolTableCollection &tables,
765 const DestToSrcToClonedSrcInDest &destToSrcToClone,
766 ArrayRef<FieldRefOpInterface> otherRefsToBeDeleted = {}
767 ) {
768 if (op->use_empty()) {
769 return success(); // safe to erase
770 }
771
772 // Helper function to determine if an Operation is contained in 'otherRefsToBeDeleted'
773 auto opWillBeDeleted = [&otherRefsToBeDeleted](Operation *op) -> bool {
774 return std::find(otherRefsToBeDeleted.begin(), otherRefsToBeDeleted.end(), op) !=
775 otherRefsToBeDeleted.end();
776 };
777
778 LLVM_DEBUG({
779 llvm::dbgs() << "[handleRemainingUses] op: " << *op << '\n';
780 llvm::dbgs() << "[handleRemainingUses] in function: " << op->getParentOfType<FuncDefOp>()
781 << '\n';
782 });
783 for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
784 if (CallOp c = llvm::dyn_cast<CallOp>(use.getOwner())) {
785 LLVM_DEBUG(llvm::dbgs() << "[handleRemainingUses] use in call: " << c << '\n');
786 unsigned argIdx = use.getOperandNumber() - c.getArgOperands().getBeginOperandIndex();
787 LLVM_DEBUG(llvm::dbgs() << "[handleRemainingUses] at index: " << argIdx << '\n');
788
789 auto tgtFuncRes = c.getCalleeTarget(tables);
790 if (failed(tgtFuncRes)) {
791 return op
792 ->emitOpError("as argument to an unknown function is not supported by this pass.")
793 .attachNote(c.getLoc())
794 .append("used by this call");
795 }
796 FuncDefOp tgtFunc = tgtFuncRes->get();
797 LLVM_DEBUG(llvm::dbgs() << "[handleRemainingUses] call target: " << tgtFunc << '\n');
798 if (tgtFunc.isExternal()) {
799 // Those without a body (i.e. external implementation) present a problem because LLZK does
800 // not define a memory layout for the external implementation to interpret the struct.
801 return op
802 ->emitOpError("as argument to a no-body free function is not supported by this pass.")
803 .attachNote(c.getLoc())
804 .append("used by this call");
805 }
806
807 FieldRefOpInterface paramFromField = TypeSwitch<Operation *, FieldRefOpInterface>(op)
808 .Case<FieldReadOp>([](auto p) { return p; })
809 .Case<CreateStructOp>([](auto p) {
810 return findOpThatStoresSubcmp(p, [&p]() { return p.emitOpError(); }).value_or(nullptr);
811 }).Default([](Operation *p) {
812 llvm::errs() << "Encountered unexpected op: "
813 << (p ? p->getName().getStringRef() : "<<null>>") << '\n';
814 llvm_unreachable("Unexpected op kind");
815 return nullptr;
816 });
817 LLVM_DEBUG({
818 llvm::dbgs() << "[handleRemainingUses] field ref op for param: "
819 << (paramFromField ? debug::toStringOne(paramFromField) : "<<null>>")
820 << '\n';
821 });
822 if (!paramFromField) {
823 return failure(); // error already printed within findOpThatStoresSubcmp()
824 }
825 const SrcStructFieldToCloneInDest &newFields =
826 destToSrcToClone.at(getDef(tables, paramFromField));
827 LLVM_DEBUG({
828 llvm::dbgs() << "[handleRemainingUses] fields to split: "
829 << debug::toStringList(newFields) << '\n';
830 });
831
832 // Convert the FuncDefOp side first (to use the easier builder for the new CallOp).
833 splitFunctionParam(tgtFunc, argIdx, newFields);
834 LLVM_DEBUG({
835 llvm::dbgs() << "[handleRemainingUses] UPDATED call target: " << tgtFunc << '\n';
836 llvm::dbgs() << "[handleRemainingUses] UPDATED call target type: "
837 << tgtFunc.getFunctionType() << '\n';
838 });
839
840 // Convert the CallOp side. Add a FieldReadOp for each value from the struct and pass them
841 // individually in place of the struct parameter.
842 {
843 OpBuilder builder(c);
844 SmallVector<Value> splitArgs;
845 // Before the CallOp, insert a read from every new field. These Values will replace the
846 // original argument in the CallOp.
847 Value originalBaseVal = paramFromField.getComponent();
848 for (auto [origName, newFieldRef] : newFields) {
849 splitArgs.push_back(builder.create<FieldReadOp>(
850 c.getLoc(), newFieldRef.getType(), originalBaseVal, newFieldRef.getNameAttr()
851 ));
852 }
853 // Generate the new argument list from the original but replace 'argIdx'
854 SmallVector<Value> newOpArgs(c.getArgOperands());
855 newOpArgs.insert(
856 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
857 );
858 // Create the new CallOp, replace uses of the old with the new, delete the old
859 c.replaceAllUsesWith(builder.create<CallOp>(
860 c.getLoc(), tgtFunc, CallOp::toVectorOfValueRange(c.getMapOperands()),
861 c.getNumDimsPerMapAttr(), newOpArgs
862 ));
863 c.erase();
864 }
865 LLVM_DEBUG({
866 llvm::dbgs() << "[handleRemainingUses] UPDATED function: "
867 << op->getParentOfType<FuncDefOp>() << '\n';
868 });
869 } else {
870 Operation *user = use.getOwner();
871 // Report an error for any user other than some field ref that will be deleted anyway.
872 if (!opWillBeDeleted(user)) {
873 return op->emitOpError()
874 .append(
875 "with use in '", user->getName().getStringRef(),
876 "' is not (currently) supported by this pass."
877 )
878 .attachNote(user->getLoc())
879 .append("used by this call");
880 }
881 }
882 }
883 // Ensure that all users of the 'op' were deleted above, or will be per 'otherRefsToBeDeleted'.
884 if (!op->use_empty()) {
885 for (Operation *user : op->getUsers()) {
886 if (!opWillBeDeleted(user)) {
887 llvm::errs() << "Op has remaining use(s) that could not be removed: " << *op << '\n';
888 llvm_unreachable("Expected all uses to be removed");
889 }
890 }
891 }
892 return success();
893 }
894
895 inline static LogicalResult finalizeStruct(
896 SymbolTableCollection &tables, StructDefOp caller, PendingErasure &&toDelete,
897 DestToSrcToClonedSrcInDest &&destToSrcToClone
898 ) {
899 LLVM_DEBUG({
900 llvm::dbgs() << "[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
901 llvm::dbgs() << caller << '\n';
902 });
903
904 // Compress chains of reads that result after inlining multiple callees.
905 caller.getConstrainFuncOp().walk([&tables, &destToSrcToClone](FieldReadOp readOp) {
906 combineReadChain(readOp, tables, destToSrcToClone);
907 });
908 auto res = caller.getComputeFuncOp().walk([&tables, &destToSrcToClone](FieldReadOp readOp) {
909 combineReadChain(readOp, tables, destToSrcToClone);
910 LogicalResult res = combineNewThenReadChain(readOp, tables, destToSrcToClone);
911 return failed(res) ? WalkResult::interrupt() : WalkResult::advance();
912 });
913 if (res.wasInterrupted()) {
914 return failure(); // error already printed within combineNewThenReadChain()
915 }
916
917 LLVM_DEBUG({
918 llvm::dbgs() << "[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
919 llvm::dbgs() << caller << '\n';
920 llvm::dbgs() << "[finalizeStruct] ops marked for deletion:\n";
921 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
922 llvm::dbgs().indent(2) << op << '\n';
923 }
924 for (CreateStructOp op : toDelete.newStructOps) {
925 llvm::dbgs().indent(2) << op << '\n';
926 }
927 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
928 llvm::dbgs().indent(2) << op << '\n';
929 }
930 });
931
932 // Handle remaining uses of CreateStructOp before deleting anything because this process
933 // needs to be able to find the writes that stores the result of these ops.
934 for (CreateStructOp op : toDelete.newStructOps) {
935 if (failed(handleRemainingUses(op, tables, destToSrcToClone, toDelete.fieldRefOps))) {
936 return failure(); // error already printed within handleRemainingUses()
937 }
938 }
939 // Next, to avoid "still has uses" errors, must erase FieldRefOpInterface before erasing
940 // the CreateStructOp or FieldDefOp.
941 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
942 if (failed(handleRemainingUses(op, tables, destToSrcToClone))) {
943 return failure(); // error already printed within handleRemainingUses()
944 }
945 op.erase();
946 }
947 for (CreateStructOp op : toDelete.newStructOps) {
948 op.erase();
949 }
950 // Finally, erase FieldDefOp via SymbolTable so table itself is updated too.
951 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
952 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
953 assert(op.getParentOp() == caller); // using correct SymbolTable
954 callerSymTab.erase(op);
955 }
956
957 return success();
958 }
959
960public:
961 void runOnOperation() override {
962 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
963 LLVM_DEBUG(useGraph.dumpToDotFile());
964
965 SymbolTableCollection tables;
966 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
967 if (failed(plan)) {
968 signalPassFailure(); // error already printed w/in makePlan()
969 return;
970 }
971
972 for (auto &[caller, callees] : plan.value()) {
973 // Cache operations that should be deleted but must wait until all callees are processed
974 // to ensure that all uses of the values defined by these operations are replaced.
975 PendingErasure toDelete;
976 // Cache old-to-new field mappings across all calleeds inlined for the current struct.
977 DestToSrcToClonedSrcInDest aggregateReplacements;
978 // Inline callees/subcomponents of the current struct
979 for (StructDefOp toInline : callees) {
980 FailureOr<DestToSrcToClonedSrcInDest> res =
981 StructInliner(tables, toDelete, toInline, caller).doInline();
982 if (failed(res)) {
983 signalPassFailure(); // error already printed w/in doInline()
984 return;
985 }
986 // Add current field replacements to the aggregate
987 for (auto &[k, v] : res.value()) {
988 assert(!aggregateReplacements.contains(k) && "duplicate not possible");
989 aggregateReplacements[k] = std::move(v);
990 }
991 }
992 // Complete steps to finalize/cleanup the caller
993 LogicalResult finalizeResult =
994 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
995 if (failed(finalizeResult)) {
996 signalPassFailure(); // error already printed w/in combineNewThenReadChain()
997 return;
998 }
999 }
1000 }
1001};
1002
1003} // namespace
1004
1005std::unique_ptr<mlir::Pass> llzk::createInlineStructsPass() {
1006 return std::make_unique<InlineStructsPass>();
1007};
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
Definition LICENSE.txt:45
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
::llvm::StringRef getFieldName()
Definition Ops.cpp.inc:1103
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.cpp.inc:836
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:509
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the FieldRefOp.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present.
Definition Ops.cpp:357
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present.
Definition Ops.cpp:353
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:531
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:734
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:759
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:729
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:739
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:314
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:333
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:628
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:632
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
Definition Ops.h.inc:641
::mlir::Region & getBody()
Definition Ops.cpp.inc:848
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:304
std::string toStringOne(const T &value)
Definition Debug.h:176
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:150
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:28
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:185
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::unique_ptr< mlir::Pass > createInlineStructsPass()
bool hasCycle(const GraphT &G)
Definition GraphUtil.h:17
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener)
A fast walk-based pattern rewrite driver.