LLZK 0.1.0
Veridise's ZK Language IR
LLZKInlineStructsPass.cpp
1//===-- LLZKInlineStructsPass.cpp -------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
19//===----------------------------------------------------------------------===//
34#include <mlir/IR/BuiltinOps.h>
35#include <mlir/Transforms/InliningUtils.h>
36#include <mlir/Transforms/WalkPatternRewriteDriver.h>
38#include <llvm/ADT/PostOrderIterator.h>
39#include <llvm/ADT/SmallVector.h>
40#include <llvm/ADT/StringMap.h>
41#include <llvm/ADT/TypeSwitch.h>
42#include <llvm/Support/Debug.h>
44// Include the generated base pass class definitions.
45namespace llzk {
46// the *DECL* macro is required to declare the options struct when a pass has options
47#define GEN_PASS_DECL_INLINESTRUCTSPASS
48#define GEN_PASS_DEF_INLINESTRUCTSPASS
50} // namespace llzk
52using namespace mlir;
53using namespace llzk;
54using namespace llzk::component;
55using namespace llzk::function;
56
57#define DEBUG_TYPE "llzk-inline-structs"
58
59namespace {
60
61using DestFieldWithSrcStructType = FieldDefOp;
62using DestCloneOfSrcStructField = FieldDefOp;
63/// Maps the name of each field in the source struct to the FieldDefOp created as a clone of that
64/// source field in the destination struct. Uses `std::map` for consistent ordering between multiple
65/// compilations of the same LLZK IR input.
66using SrcStructFieldToCloneInDest = std::map<StringRef, DestCloneOfSrcStructField>;
69using DestToSrcToClonedSrcInDest =
70 DenseMap<DestFieldWithSrcStructType, SrcStructFieldToCloneInDest>;
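// Illustrative sketch (hypothetical names): if the destination struct @Main has a field @s of
// type !struct.type<@Sub>, and @Sub declares a field @x, then after `cloneFields()` this map
// holds an entry { @s -> { "x" -> the cloned @"s:!s<@Sub>+x" field in @Main } }.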
71
72Value getSelfValue(FuncDefOp f) {
73 if (f.nameIsCompute()) {
74 return f.getSelfValueFromCompute();
75 } else if (f.nameIsConstrain()) {
76 return f.getSelfValueFromConstrain();
77 } else {
78 llvm_unreachable("expected \"compute\" or \"constrain\" function");
79 }
80}
81
82inline FieldDefOp getDef(SymbolTableCollection &tables, FieldRefOpInterface fRef) {
83 auto r = fRef.getFieldDefOp(tables);
84 assert(succeeded(r));
85 return r->get();
86}
87
89
91bool combineHelper(
92 FieldReadOp readOp, SymbolTableCollection &tables,
93 const DestToSrcToClonedSrcInDest &destToSrcToClone, FieldRefOpInterface destFieldRefOp
94) {
95 auto srcToClone = destToSrcToClone.find(getDef(tables, destFieldRefOp));
96 if (srcToClone == destToSrcToClone.end()) {
97 return false;
98 }
99 SrcStructFieldToCloneInDest oldToNewFields = srcToClone->second;
100 auto resNewField = oldToNewFields.find(readOp.getFieldName());
101 if (resNewField == oldToNewFields.end()) {
102 return false;
103 }
104
105 // Replace this FieldReadOp with a new one that targets the cloned field.
106 OpBuilder builder(readOp);
107 FieldReadOp newRead = builder.create<FieldReadOp>(
108 readOp.getLoc(), readOp.getType(), destFieldRefOp.getComponent(),
109 resNewField->second.getNameAttr()
110 );
111 readOp.replaceAllUsesWith(newRead.getOperation());
112 readOp.erase(); // delete the original FieldReadOp
113 return true;
114}
115
129bool combineReadChain(
130 FieldReadOp readOp, SymbolTableCollection &tables,
131 const DestToSrcToClonedSrcInDest &destToSrcToClone
132) {
133 FieldReadOp readThatDefinesBaseComponent =
134 llvm::dyn_cast_if_present<FieldReadOp>(readOp.getComponent().getDefiningOp());
135 if (!readThatDefinesBaseComponent) {
136 return false;
137 }
138 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
139}
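// Illustrative sketch of the rewrite performed via `combineHelper()` (hypothetical names; the
// `struct.readf` assembly format follows the example in `FieldRefRewriter` below):
//   before: %0 = struct.readf %self[@s] : <@Main>, !struct.type<@Sub>
//           %1 = struct.readf %0[@x] : <@Sub>, !felt.type
//   after:  %1 = struct.readf %self[@"s:!s<@Sub>+x"] : <@Main>, !felt.type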
140
143FailureOr<FieldWriteOp>
144findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
145 FieldWriteOp foundWrite = nullptr;
146 for (Operation *user : writtenValue.getUsers()) {
147 if (FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(user)) {
148 // Find the write op that stores the created value
149 if (writeOp.getVal() == writtenValue) {
150 if (foundWrite) {
151 // Note: There is no reason for a subcomponent to be stored to more than one field.
152 auto diag = emitError().append("result should not be written to more than one field.");
153 diag.attachNote(foundWrite.getLoc()).append("written here");
154 diag.attachNote(writeOp.getLoc()).append("written here");
155 return diag;
156 } else {
157 foundWrite = writeOp;
158 }
159 }
160 }
161 }
162 if (!foundWrite) {
163 // Note: There is no reason to construct a subcomponent and not store it to a field.
164 return emitError().append("result should be written to a field.");
165 }
166 return foundWrite;
167}
168
185LogicalResult combineNewThenReadChain(
186 FieldReadOp readOp, SymbolTableCollection &tables,
187 const DestToSrcToClonedSrcInDest &destToSrcToClone
188) {
189 CreateStructOp createThatDefinesBaseComponent =
190 llvm::dyn_cast_if_present<CreateStructOp>(readOp.getComponent().getDefiningOp());
191 if (!createThatDefinesBaseComponent) {
192 return success(); // No error. The pattern simply doesn't match.
193 }
194 FailureOr<FieldWriteOp> foundWrite =
195 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
196 return createThatDefinesBaseComponent.emitOpError();
197 });
198 if (failed(foundWrite)) {
199 return failure(); // error already printed within findOpThatStoresSubcmp()
200 }
201 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
202}
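// Illustrative sketch (hypothetical names): within "compute", a `CreateStructOp` result of type
// !struct.type<@Sub> that is written to field @s of the destination struct and later read via
// `struct.readf %n[@x]` is rewritten so the read targets %self[@"s:!s<@Sub>+x"] directly,
// mirroring the `combineReadChain()` case above.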
203
204inline FieldReadOp getFieldReadThatDefinesSelfValuePassedToConstrain(CallOp callOp) {
205 Value selfArgFromCall = callOp.getSelfValueFromConstrain();
206 return llvm::dyn_cast_if_present<FieldReadOp>(selfArgFromCall.getDefiningOp());
207}
208
211struct PendingErasure {
212 SmallVector<FieldRefOpInterface> fieldRefOps;
213 SmallVector<CreateStructOp> newStructOps;
214 SmallVector<DestFieldWithSrcStructType> fieldDefs;
215};
216
217class StructInliner {
218 SymbolTableCollection &tables;
219 PendingErasure &toDelete;
221 StructDefOp srcStruct;
223 StructDefOp destStruct;
224
225 inline FieldDefOp getDef(FieldRefOpInterface fRef) const { return ::getDef(tables, fRef); }
226
227 // Update field read/write ops that target the "self" value of the FuncDefOp plus some key in
228 // `oldToNewFieldDef` to instead target the new base Value provided to the constructor plus the
229 // mapped field from `oldToNewFieldDef`.
230 // Example:
231 // old: %1 = struct.readf %0[@f1] : <@Component1A>, !felt.type
232 // new: %1 = struct.readf %self[@"f2:!s<@Component1A>+f1"] : <@Component1B>, !felt.type
233 class FieldRefRewriter final : public OpInterfaceRewritePattern<FieldRefOpInterface> {
236 FuncDefOp funcRef;
238 Value oldBaseVal;
240 Value newBaseVal;
241 const SrcStructFieldToCloneInDest &oldToNewFields;
242
243 public:
244 FieldRefRewriter(
245 FuncDefOp originalFunc, Value newRefBase,
246 const SrcStructFieldToCloneInDest &oldToNewFieldDef
247 )
248 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
249 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewFields(oldToNewFieldDef) {}
250
251 LogicalResult match(FieldRefOpInterface op) const final {
252 assert(oldBaseVal); // ensure it's used via `cloneWithFieldRefUpdate()` only
253 // Check if the FieldRef accesses a field of "self" within the `oldToNewFields` map.
254 // Per `cloneWithFieldRefUpdate()`, `oldBaseVal` is the "self" value of `funcRef` so
255 // check for a match there and then check that the referenced field name is in the map.
256 return success(op.getComponent() == oldBaseVal && oldToNewFields.contains(op.getFieldName()));
257 }
258
259 void rewrite(FieldRefOpInterface op, PatternRewriter &rewriter) const final {
260 rewriter.modifyOpInPlace(op, [this, &op]() {
261 DestCloneOfSrcStructField newF = oldToNewFields.at(op.getFieldName());
262 op.setFieldName(newF.getSymName());
263 op.getComponentMutable().set(this->newBaseVal);
264 });
265 }
266
269 static FuncDefOp cloneWithFieldRefUpdate(std::unique_ptr<FieldRefRewriter> thisPat) {
270 IRMapping mapper;
271 FuncDefOp srcFuncClone = thisPat->funcRef.clone(mapper);
272 // Update some data in the `FieldRefRewriter` instance before moving it.
273 thisPat->funcRef = srcFuncClone;
274 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
275 // Run the rewriter to replace read/write ops
276 MLIRContext *ctx = thisPat->getContext();
277 RewritePatternSet patterns(ctx, std::move(thisPat));
278 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
279
280 return srcFuncClone;
281 }
282 };
283
285 class ImplBase {
286 protected:
287 const StructInliner &data;
288 const DestToSrcToClonedSrcInDest &destToSrcToClone;
289
292 virtual FieldRefOpInterface getSelfRefField(CallOp callOp) = 0;
293 virtual void processCloneBeforeInlining(FuncDefOp func) {}
294 virtual ~ImplBase() = default;
295
296 public:
297 ImplBase(const StructInliner &inliner, const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
298 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
299
300 LogicalResult doInlining(FuncDefOp srcFunc, FuncDefOp destFunc) {
301 LLVM_DEBUG({
302 llvm::dbgs() << "[doInlining] SOURCE FUNCTION:\n";
303 srcFunc.dump();
304 llvm::dbgs() << "[doInlining] DESTINATION FUNCTION:\n";
305 destFunc.dump();
306 });
307
308 InlinerInterface inliner(destFunc.getContext());
309
311 auto callHandler = [this, &inliner, &srcFunc](CallOp callOp) {
312 // Ensure the CallOp targets `srcFunc`
313 auto callOpTarget = callOp.getCalleeTarget(this->data.tables);
314 assert(succeeded(callOpTarget));
315 if (callOpTarget->get() != srcFunc) {
316 return WalkResult::advance();
317 }
318
319 // Get the "self" struct parameter from the CallOp and determine which field that struct
320 // was stored in within the caller (i.e. `destFunc`).
321 FieldRefOpInterface selfFieldRefOp = this->getSelfRefField(callOp);
322 if (!selfFieldRefOp) {
323 // Note: error message was already printed within `getSelfRefField()`
324 return WalkResult::interrupt(); // use interrupt to signal failure
325 }
326
327 // Create a clone of the source function (must do the whole function not just the body
328 // region because `inlineCall()` expects the Region to have a parent op) and update field
329 // references to the old struct fields to instead use the new struct fields.
330 FuncDefOp srcFuncClone = FieldRefRewriter::cloneWithFieldRefUpdate(
331 std::make_unique<FieldRefRewriter>(
332 srcFunc, selfFieldRefOp.getComponent(),
333 this->destToSrcToClone.at(this->data.getDef(selfFieldRefOp))
334 )
335 );
336 this->processCloneBeforeInlining(srcFuncClone);
337
338 // Inline the cloned function in place of `callOp`
339 LogicalResult inlineCallRes =
340 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.getBody(), false);
341 if (failed(inlineCallRes)) {
342 callOp.emitError().append("Failed to inline ", srcFunc.getFullyQualifiedName()).report();
343 return WalkResult::interrupt(); // use interrupt to signal failure
344 }
345 srcFuncClone.erase(); // delete what's left after transferring the body elsewhere
346 callOp.erase(); // delete the original CallOp
347 return WalkResult::skip(); // Must skip because the CallOp was erased.
348 };
349
350 auto fieldWriteHandler = [this](FieldWriteOp writeOp) {
351 // Check if the field ref op should be deleted in the end
352 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
353 this->data.toDelete.fieldRefOps.push_back(writeOp);
354 }
355 return WalkResult::advance();
356 };
357
360 auto fieldReadHandler = [this](FieldReadOp readOp) {
361 // Check if the field ref op should be deleted in the end
362 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
363 this->data.toDelete.fieldRefOps.push_back(readOp);
364 }
365 // If the FieldReadOp was replaced/erased, must skip.
366 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
367 ? WalkResult::skip()
368 : WalkResult::advance();
369 };
370
371 WalkResult walkRes = destFunc.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
372 return TypeSwitch<Operation *, WalkResult>(op)
373 .Case<CallOp>(callHandler)
374 .Case<FieldWriteOp>(fieldWriteHandler)
375 .Case<FieldReadOp>(fieldReadHandler)
376 .Default([](Operation *) { return WalkResult::advance(); });
377 });
378
379 return failure(walkRes.wasInterrupted());
380 }
381 };
382
383 class ConstrainImpl : public ImplBase {
384 using ImplBase::ImplBase;
385
386 FieldRefOpInterface getSelfRefField(CallOp callOp) override {
387 // The typical pattern is to read a struct instance from a field and then call "constrain()"
388 // on it. Get the Value passed as the "self" struct to the CallOp and determine which field it
389 // was read from in the current struct (i.e., `destStruct`).
390 FieldRefOpInterface selfFieldRef = getFieldReadThatDefinesSelfValuePassedToConstrain(callOp);
391 if (selfFieldRef &&
392 selfFieldRef.getComponent().getType() == this->data.destStruct.getType()) {
393 return selfFieldRef;
394 }
395 callOp.emitError()
396 .append(
397 "expected \"self\" parameter to \"@", FUNC_NAME_CONSTRAIN,
398 "\" to be passed a value read from a field in the current stuct."
399 )
400 .report();
401 return nullptr;
402 }
403 };
404
405 class ComputeImpl : public ImplBase {
406 using ImplBase::ImplBase;
407
408 FieldRefOpInterface getSelfRefField(CallOp callOp) override {
409 // The typical pattern is to write the return value of "compute()" to a field in
410 // the current struct (i.e., `destStruct`).
411 // It doesn't really make sense (although there is no semantic restriction against it) to just
412 // pass the "compute()" result into another function and never write it to a field since that
413 // leaves no way for the "constrain()" function to call "constrain()" on that result struct.
414 FailureOr<FieldWriteOp> foundWrite =
415 findOpThatStoresSubcmp(callOp.getSelfValueFromCompute(), [&callOp]() {
416 return callOp.emitOpError().append("\"@", FUNC_NAME_COMPUTE, "\" ");
417 });
418 return static_cast<FieldRefOpInterface>(foundWrite.value_or(nullptr));
419 }
420
421 void processCloneBeforeInlining(FuncDefOp func) override {
422 // Within the compute function, find `CreateStructOp` with `srcStruct` type and mark them
423 // for later deletion. The deletion must occur later because these values may still have
424 // uses until ALL callees of a function have been inlined.
425 func.getBody().walk([this](CreateStructOp newStructOp) {
426 if (newStructOp.getType() == this->data.srcStruct.getType()) {
427 this->data.toDelete.newStructOps.push_back(newStructOp);
428 }
429 });
430 }
431 };
432
433 // Find any field(s) in `destStruct` whose type matches `srcStruct` (allowing any parameters, if
434 // applicable). For each such field, clone all fields from `srcStruct` into `destStruct` and cache
435 // the mapping of `destStruct` to `srcStruct` to cloned fields in the return value.
436 DestToSrcToClonedSrcInDest cloneFields() {
437 DestToSrcToClonedSrcInDest destToSrcToClone;
438
439 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
440 StructType srcStructType = srcStruct.getType();
441 for (FieldDefOp destField : destStruct.getFieldDefs()) {
442 if (StructType destFieldType = llvm::dyn_cast<StructType>(destField.getType())) {
443 UnificationMap unifications;
444 if (!structTypesUnify(srcStructType, destFieldType, {}, &unifications)) {
445 continue;
446 }
447 assert(unifications.empty()); // `makePlan()` reports failure earlier
448 // Mark the original `destField` for deletion
449 toDelete.fieldDefs.push_back(destField);
450 // Clone each field from 'srcStruct' into 'destStruct'. Add an entry to `destToSrcToClone`
451 // even if there are no fields in `srcStruct` so its presence can be used as a marker.
452 SrcStructFieldToCloneInDest &srcToClone = destToSrcToClone[destField];
453 std::vector<FieldDefOp> srcFields = srcStruct.getFieldDefs();
454 if (srcFields.empty()) {
455 continue;
456 }
457 OpBuilder builder(destField);
458 std::string newNameBase =
459 destField.getName().str() + ':' + BuildShortTypeString::from(destFieldType);
460 for (FieldDefOp srcField : srcFields) {
461 DestCloneOfSrcStructField newF = llvm::cast<FieldDefOp>(builder.clone(*srcField));
462 newF.setName(builder.getStringAttr(newNameBase + '+' + newF.getName()));
463 srcToClone[srcField.getSymNameAttr()] = newF;
464 // Also update the cached SymbolTable
465 destStructSymTable.insert(newF);
466 }
467 }
468 }
469 return destToSrcToClone;
470 }
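  // For example (names as in the `FieldRefRewriter` comment above): a `destStruct` field @f2 of
  // type !struct.type<@Component1A> combined with a `srcStruct` field @f1 yields a cloned field
  // named @"f2:!s<@Component1A>+f1" in `destStruct`.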
471
473 inline LogicalResult inlineConstrainCall(const DestToSrcToClonedSrcInDest &destToSrcToClone) {
474 return ConstrainImpl(*this, destToSrcToClone)
475 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
476 }
477
479 inline LogicalResult inlineComputeCall(const DestToSrcToClonedSrcInDest &destToSrcToClone) {
480 return ComputeImpl(*this, destToSrcToClone)
481 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
482 }
483
484public:
485 StructInliner(
486 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp from, StructDefOp into
487 )
488 : tables(tbls), toDelete(opsToDelete), srcStruct(from), destStruct(into) {}
489
490 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
491 LLVM_DEBUG(
492 llvm::dbgs() << "[StructInliner] merge " << srcStruct.getSymNameAttr() << " into "
493 << destStruct.getSymNameAttr() << '\n'
494 );
495
496 DestToSrcToClonedSrcInDest destToSrcToClone = cloneFields();
497 if (failed(inlineConstrainCall(destToSrcToClone)) ||
498 failed(inlineComputeCall(destToSrcToClone))) {
499 return failure(); // error already printed within doInlining()
500 }
501 return destToSrcToClone;
502 }
503};
504
508inline void splitFunctionParam(
509 FuncDefOp func, unsigned paramIdx, const SrcStructFieldToCloneInDest &nameToNewField
510) {
511 class Impl : public FunctionTypeConverter {
512 unsigned inputIdx;
513 const SrcStructFieldToCloneInDest &newFields;
514
515 public:
516 Impl(unsigned paramIdx, const SrcStructFieldToCloneInDest &nameToNewField)
517 : inputIdx(paramIdx), newFields(nameToNewField) {}
518
519 protected:
520 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
521 SmallVector<Type> newTypes(origTypes);
522 auto it = newTypes.erase(newTypes.begin() + inputIdx);
523 for (auto [_, newField] : newFields) {
524 newTypes.insert(it, newField.getType());
525 ++it;
526 }
527 return newTypes;
528 }
529 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
530 return SmallVector<Type>(origTypes);
531 }
532 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type>) override {
533 if (origAttrs) {
534 // Replicate the value at `origAttrs[inputIdx]` so there are `newFields.size()` copies of it
535 SmallVector<Attribute> newAttrs(origAttrs.getValue());
536 newAttrs.insert(newAttrs.begin() + inputIdx, newFields.size() - 1, origAttrs[inputIdx]);
537 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
538 }
539 return nullptr;
540 }
541 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type>) override {
542 return origAttrs;
543 }
544
545 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
546 Value oldStructRef = entryBlock.getArgument(inputIdx);
547
548 // Insert new Block arguments, one per field, following the original one. Keep a map
549 // of field name to the associated block argument for replacing FieldReadOp.
550 llvm::StringMap<BlockArgument> fieldNameToNewArg;
551 Location loc = oldStructRef.getLoc();
552 unsigned idx = inputIdx;
553 for (auto [fieldName, newField] : newFields) {
554 // note: pre-increment so the original to be erased is still at `inputIdx`
555 BlockArgument newArg = entryBlock.insertArgument(++idx, newField.getType(), loc);
556 fieldNameToNewArg[fieldName] = newArg;
557 }
558
559 // Find all field reads from the original Block argument and replace uses of those
560 // reads with the appropriate new Block argument.
561 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
562 if (FieldReadOp readOp = llvm::dyn_cast<FieldReadOp>(oldBlockArgUse.getOwner())) {
563 if (readOp.getComponent() == oldStructRef) {
564 BlockArgument newArg = fieldNameToNewArg.at(readOp.getFieldName());
565 rewriter.replaceAllUsesWith(readOp, newArg);
566 rewriter.eraseOp(readOp);
567 continue;
568 }
569 }
570 // Currently, there's no other way in which a StructType parameter can be used.
571 llvm::errs() << "Unexpected use of " << oldBlockArgUse.get() << " in "
572 << *oldBlockArgUse.getOwner() << '\n';
573 llvm_unreachable("Not yet implemented");
574 }
575
576 // Delete the original Block argument
577 entryBlock.eraseArgument(inputIdx);
578 }
579 };
580 IRRewriter rewriter(func.getContext());
581 Impl(paramIdx, nameToNewField).convert(func, rewriter);
582}
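// Illustrative sketch (hypothetical names): given cloned fields {@"s:!s<@Sub>+x", @"s:!s<@Sub>+y"},
// a parameter of type !struct.type<@Sub> at `paramIdx` is replaced by one parameter per cloned
// field, and each `struct.readf` from the old block argument is replaced with the corresponding
// new block argument before the original argument is erased.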
583
584class InlineStructsPass : public llzk::impl::InlineStructsPassBase<InlineStructsPass> {
591 using InliningPlan = SmallVector<std::pair<StructDefOp, SmallVector<StructDefOp>>>;
592
593 static uint64_t complexity(FuncDefOp f) {
594 uint64_t complexity = 0;
595 f.getBody().walk([&complexity](Operation *op) {
596 if (llvm::isa<felt::MulFeltOp>(op)) {
597 ++complexity;
598 } else if (auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
599 complexity += computeEmitEqCardinality(ee.getLhs().getType());
600 } else if (auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
601 // TODO: increment based on dimension sizes in the operands
602 // Pending update to implementation/semantics of EmitContainmentOp.
603 ++complexity;
604 }
605 });
606 return complexity;
607 }
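  // Worked example (assuming `computeEmitEqCardinality()` returns the number of field-element
  // comparisons implied by the operand type): a "constrain" function containing 3 MulFeltOp
  // (felt multiplication) ops and one EmitEqualityOp over a 4-element felt array has
  // complexity 3 + 4 = 7.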
608
609 static FailureOr<FuncDefOp>
610 getIfStructConstrain(const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
611 auto lookupRes = node->lookupSymbol(tables, false);
612 assert(succeeded(lookupRes) && "graph contains node with invalid path");
613 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
614 if (f.isStructConstrain()) {
615 return f;
616 }
617 }
618 return failure();
619 }
620
623 static inline StructDefOp getParentStruct(FuncDefOp func) {
624 assert(func.isStructConstrain()); // pre-condition
625 FailureOr<StructDefOp> currentNodeParentStruct = getParentOfType<StructDefOp>(func);
626 assert(succeeded(currentNodeParentStruct)); // follows from ODS definition
627 return currentNodeParentStruct.value();
628 }
629
631 inline bool exceedsMaxComplexity(uint64_t check) {
632 return maxComplexity > 0 && check > maxComplexity;
633 }
634
637 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
638 // Find CallOp for `successorFunc` within `currentFunc` and check the condition used by
639 // `ConstrainImpl::getSelfRefField()`.
640 //
641 // Implementation Note: There is a possibility that the "self" value is not from a field read.
642 // It could be a parameter to the current/destination function or a global read. Inlining a
643 // struct stored to a global would probably require splitting up the global into multiple, one
644 // for each field in the successor/source struct. That may not be a good idea. The parameter
645 // case could be handled but it will not have a mapping in `destToSrcToClone` in
646 // `getSelfRefField()` and new fields will still need to be added. They can be prefixed with
647 // parameter index since there is no current field name to use as the unique prefix. Handling
648 // that would require refactoring the inlining process a bit.
649 WalkResult res = currentFunc.walk([](CallOp c) {
650 return getFieldReadThatDefinesSelfValuePassedToConstrain(c)
651 ? WalkResult::interrupt() // use interrupt to indicate success
652 : WalkResult::advance();
653 });
654 LLVM_DEBUG({
655 llvm::dbgs() << "[canInline] " << successorFunc.getFullyQualifiedName() << " into "
656 << currentFunc.getFullyQualifiedName() << "? " << res.wasInterrupted() << '\n';
657 });
658 return res.wasInterrupted();
659 }
660
665 inline FailureOr<InliningPlan>
666 makePlan(const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
667 LLVM_DEBUG({
668 llvm::dbgs() << "Running InlineStructsPass with max complexity ";
669 if (maxComplexity == 0) {
670 llvm::dbgs() << "unlimited";
671 } else {
672 llvm::dbgs() << maxComplexity;
673 }
674 llvm::dbgs() << '\n';
675 });
676 InliningPlan retVal;
677 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
678
679 // NOTE: The assumption that the use graph has no cycles allows `complexityMemo` to only
680 // store the result for relevant nodes and assume nodes without a mapped value are `0`. This
681 // must be true of the "compute"/"constrain" function uses and field defs because circuits
682 // must be acyclic. This is likely true for the symbol use graph in general, but if a
683 // counterexample is ever found, the algorithm below must be re-evaluated.
684 assert(!hasCycle(&useGraph));
685
686 // Traverse "constrain" function nodes to compute their complexity and an inlining plan. Use
687 // post-order traversal so the complexity of all successor nodes is computed before computing
688 // the current node's complexity.
689 for (const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
690 LLVM_DEBUG(llvm::dbgs() << "\ncurrentNode = " << currentNode->toString());
691 if (!currentNode->isRealNode()) {
692 continue;
693 }
694 if (currentNode->isStructParam()) {
695 // Try to get the location of the StructDefOp to report an error.
696 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
697 SymbolRefAttr prefix = getPrefixAsSymbolRefAttr(currentNode->getSymbolPath());
698 auto res = lookupSymbolIn<StructDefOp>(tables, prefix, lookupFrom, lookupFrom, false);
699 // If that lookup didn't work for some reason, report at the path root location.
700 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
701 return reportLoc->emitError("Cannot inline structs with parameters.");
702 }
703 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
704 if (failed(currentFuncOpt)) {
705 continue;
706 }
707 FuncDefOp currentFunc = currentFuncOpt.value();
708 uint64_t currentComplexity = complexity(currentFunc);
709 // If the current complexity is already too high, store it and continue.
710 if (exceedsMaxComplexity(currentComplexity)) {
711 complexityMemo[currentNode] = currentComplexity;
712 continue;
713 }
714 // Otherwise, make a plan that adds successor "constrain" functions unless the
715 // complexity becomes too high by adding that successor.
716 SmallVector<StructDefOp> successorsToMerge;
717 for (const SymbolUseGraphNode *successor : currentNode->successorIter()) {
718 LLVM_DEBUG(llvm::dbgs().indent(2) << "successor: " << successor->toString() << '\n');
719 // Note: all "constrain" function nodes will have a value, and all other nodes will not.
720 auto memoResult = complexityMemo.find(successor);
721 if (memoResult == complexityMemo.end()) {
722 continue; // inner loop
723 }
724 uint64_t sComplexity = memoResult->second;
725 assert(
726 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
727 "addition will overflow"
728 );
729 uint64_t potentialComplexity = currentComplexity + sComplexity;
730 if (!exceedsMaxComplexity(potentialComplexity)) {
731 currentComplexity = potentialComplexity;
732 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
733 assert(succeeded(successorFuncOpt)); // follows from the Note above
734 FuncDefOp successorFunc = successorFuncOpt.value();
735 if (canInline(currentFunc, successorFunc)) {
736 successorsToMerge.push_back(getParentStruct(successorFunc));
737 }
738 }
739 }
740 complexityMemo[currentNode] = currentComplexity;
741 if (!successorsToMerge.empty()) {
742 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
743 }
744 }
745 LLVM_DEBUG({
746 llvm::dbgs() << "-----------------------------------------------------------------\n";
747 llvm::dbgs() << "InlineStructsPass plan:\n";
748 for (auto &[caller, callees] : retVal) {
749 llvm::dbgs().indent(2) << "inlining the following into \"" << caller.getSymName() << "\"\n";
750 for (StructDefOp c : callees) {
751 llvm::dbgs().indent(4) << "\"" << c.getSymName() << "\"\n";
752 }
753 }
754 llvm::dbgs() << "-----------------------------------------------------------------\n";
755 });
756 return retVal;
757 }
758
764 static LogicalResult handleRemainingUses(
765 Operation *op, SymbolTableCollection &tables,
766 const DestToSrcToClonedSrcInDest &destToSrcToClone,
767 ArrayRef<FieldRefOpInterface> otherRefsToBeDeleted = {}
768 ) {
769 if (op->use_empty()) {
770 return success(); // safe to erase
771 }
772
773 // Helper function to determine if an Operation is contained in 'otherRefsToBeDeleted'
774 auto opWillBeDeleted = [&otherRefsToBeDeleted](Operation *otherOp) -> bool {
775 return std::find(otherRefsToBeDeleted.begin(), otherRefsToBeDeleted.end(), otherOp) !=
776 otherRefsToBeDeleted.end();
777 };
778
779 LLVM_DEBUG({
780 llvm::dbgs() << "[handleRemainingUses] op: " << *op << '\n';
781 llvm::dbgs() << "[handleRemainingUses] in function: " << op->getParentOfType<FuncDefOp>()
782 << '\n';
783 });
784 for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
785 if (CallOp c = llvm::dyn_cast<CallOp>(use.getOwner())) {
786 LLVM_DEBUG(llvm::dbgs() << "[handleRemainingUses] use in call: " << c << '\n');
787 unsigned argIdx = use.getOperandNumber() - c.getArgOperands().getBeginOperandIndex();
788 LLVM_DEBUG(llvm::dbgs() << "[handleRemainingUses] at index: " << argIdx << '\n');
789
790 auto tgtFuncRes = c.getCalleeTarget(tables);
791 if (failed(tgtFuncRes)) {
792 return op
793 ->emitOpError("as argument to an unknown function is not supported by this pass.")
794 .attachNote(c.getLoc())
795 .append("used by this call");
796 }
797 FuncDefOp tgtFunc = tgtFuncRes->get();
798 LLVM_DEBUG(llvm::dbgs() << "[handleRemainingUses] call target: " << tgtFunc << '\n');
799 if (tgtFunc.isExternal()) {
800 // Those without a body (i.e. external implementation) present a problem because LLZK does
801 // not define a memory layout for the external implementation to interpret the struct.
802 return op
803 ->emitOpError("as argument to a no-body free function is not supported by this pass.")
804 .attachNote(c.getLoc())
805 .append("used by this call");
806 }
807
808 FieldRefOpInterface paramFromField = TypeSwitch<Operation *, FieldRefOpInterface>(op)
809 .Case<FieldReadOp>([](auto p) { return p; })
810 .Case<CreateStructOp>([](auto p) {
811 return findOpThatStoresSubcmp(p, [&p]() { return p.emitOpError(); }).value_or(nullptr);
812 }).Default([](Operation *p) {
813 llvm::errs() << "Encountered unexpected op: "
814 << (p ? p->getName().getStringRef() : "<<null>>") << '\n';
815 llvm_unreachable("Unexpected op kind");
816 return nullptr;
817 });
818 LLVM_DEBUG({
819 llvm::dbgs() << "[handleRemainingUses] field ref op for param: "
820 << (paramFromField ? debug::toStringOne(paramFromField) : "<<null>>")
821 << '\n';
822 });
823 if (!paramFromField) {
824 return failure(); // error already printed within findOpThatStoresSubcmp()
825 }
826 const SrcStructFieldToCloneInDest &newFields =
827 destToSrcToClone.at(getDef(tables, paramFromField));
828 LLVM_DEBUG({
829 llvm::dbgs() << "[handleRemainingUses] fields to split: "
830 << debug::toStringList(newFields) << '\n';
831 });
832
833 // Convert the FuncDefOp side first (to use the easier builder for the new CallOp).
834 splitFunctionParam(tgtFunc, argIdx, newFields);
835 LLVM_DEBUG({
836 llvm::dbgs() << "[handleRemainingUses] UPDATED call target: " << tgtFunc << '\n';
837 llvm::dbgs() << "[handleRemainingUses] UPDATED call target type: "
838 << tgtFunc.getFunctionType() << '\n';
839 });
840
841 // Convert the CallOp side. Add a FieldReadOp for each value from the struct and pass them
842 // individually in place of the struct parameter.
843 {
844 OpBuilder builder(c);
845 SmallVector<Value> splitArgs;
846 // Before the CallOp, insert a read from every new field. These Values will replace the
847 // original argument in the CallOp.
848 Value originalBaseVal = paramFromField.getComponent();
849 for (auto [origName, newFieldRef] : newFields) {
850 splitArgs.push_back(builder.create<FieldReadOp>(
851 c.getLoc(), newFieldRef.getType(), originalBaseVal, newFieldRef.getNameAttr()
852 ));
853 }
854 // Generate the new argument list from the original but replace 'argIdx'
855 SmallVector<Value> newOpArgs(c.getArgOperands());
856 newOpArgs.insert(
857 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
858 );
859 // Create the new CallOp, replace uses of the old with the new, delete the old
860 c.replaceAllUsesWith(builder.create<CallOp>(
861 c.getLoc(), tgtFunc, CallOp::toVectorOfValueRange(c.getMapOperands()),
862 c.getNumDimsPerMapAttr(), newOpArgs
863 ));
864 c.erase();
865 }
866 LLVM_DEBUG({
867 llvm::dbgs() << "[handleRemainingUses] UPDATED function: "
868 << op->getParentOfType<FuncDefOp>() << '\n';
869 });
870 } else {
871 Operation *user = use.getOwner();
872 // Report an error for any user other than some field ref that will be deleted anyway.
873 if (!opWillBeDeleted(user)) {
874 return op->emitOpError()
875 .append(
876 "with use in '", user->getName().getStringRef(),
877 "' is not (currently) supported by this pass."
878 )
879 .attachNote(user->getLoc())
880 .append("used by this call");
881 }
882 }
883 }
884 // Ensure that all users of the 'op' were deleted above, or will be per 'otherRefsToBeDeleted'.
885 if (!op->use_empty()) {
886 for (Operation *user : op->getUsers()) {
887 if (!opWillBeDeleted(user)) {
888 llvm::errs() << "Op has remaining use(s) that could not be removed: " << *op << '\n';
889 llvm_unreachable("Expected all uses to be removed");
890 }
891 }
892 }
893 return success();
894 }
895
896 inline static LogicalResult finalizeStruct(
897 SymbolTableCollection &tables, StructDefOp caller, PendingErasure &&toDelete,
898 DestToSrcToClonedSrcInDest &&destToSrcToClone
899 ) {
900 LLVM_DEBUG({
901 llvm::dbgs() << "[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
902 llvm::dbgs() << caller << '\n';
903 });
904
905 // Compress chains of reads that result after inlining multiple callees.
906 caller.getConstrainFuncOp().walk([&tables, &destToSrcToClone](FieldReadOp readOp) {
907 combineReadChain(readOp, tables, destToSrcToClone);
908 });
909 auto res = caller.getComputeFuncOp().walk([&tables, &destToSrcToClone](FieldReadOp readOp) {
910 combineReadChain(readOp, tables, destToSrcToClone);
911 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
912 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
913 });
914 if (res.wasInterrupted()) {
915 return failure(); // error already printed within combineNewThenReadChain()
916 }
917
918 LLVM_DEBUG({
919 llvm::dbgs() << "[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
920 llvm::dbgs() << caller << '\n';
921 llvm::dbgs() << "[finalizeStruct] ops marked for deletion:\n";
922 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
923 llvm::dbgs().indent(2) << op << '\n';
924 }
925 for (CreateStructOp op : toDelete.newStructOps) {
926 llvm::dbgs().indent(2) << op << '\n';
927 }
928 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
929 llvm::dbgs().indent(2) << op << '\n';
930 }
931 });
932
933 // Handle remaining uses of CreateStructOp before deleting anything because this process
934 // needs to be able to find the writes that store the results of these ops.
935 for (CreateStructOp op : toDelete.newStructOps) {
936 if (failed(handleRemainingUses(op, tables, destToSrcToClone, toDelete.fieldRefOps))) {
937 return failure(); // error already printed within handleRemainingUses()
938 }
939 }
940 // Next, to avoid "still has uses" errors, must erase FieldRefOpInterface before erasing
941 // the CreateStructOp or FieldDefOp.
942 for (FieldRefOpInterface op : toDelete.fieldRefOps) {
943 if (failed(handleRemainingUses(op, tables, destToSrcToClone))) {
944 return failure(); // error already printed within handleRemainingUses()
945 }
946 op.erase();
947 }
948 for (CreateStructOp op : toDelete.newStructOps) {
949 op.erase();
950 }
951 // Finally, erase FieldDefOp via SymbolTable so table itself is updated too.
952 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
953 for (DestFieldWithSrcStructType op : toDelete.fieldDefs) {
954 assert(op.getParentOp() == caller); // using correct SymbolTable
955 callerSymTab.erase(op);
956 }
957
958 return success();
959 }
960
961public:
962 void runOnOperation() override {
963 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
964 LLVM_DEBUG(useGraph.dumpToDotFile());
965
966 SymbolTableCollection tables;
967 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
968 if (failed(plan)) {
969 signalPassFailure(); // error already printed w/in makePlan()
970 return;
971 }
972
973 for (auto &[caller, callees] : plan.value()) {
974 // Cache operations that should be deleted but must wait until all callees are processed
975 // to ensure that all uses of the values defined by these operations are replaced.
976 PendingErasure toDelete;
977 // Cache old-to-new field mappings across all callees inlined for the current struct.
978 DestToSrcToClonedSrcInDest aggregateReplacements;
979 // Inline callees/subcomponents of the current struct
980 for (StructDefOp toInline : callees) {
981 FailureOr<DestToSrcToClonedSrcInDest> res =
982 StructInliner(tables, toDelete, toInline, caller).doInline();
983 if (failed(res)) {
984 signalPassFailure(); // error already printed w/in doInline()
985 return;
986 }
987 // Add current field replacements to the aggregate
988 for (auto &[k, v] : res.value()) {
989 assert(!aggregateReplacements.contains(k) && "duplicate not possible");
990 aggregateReplacements[k] = std::move(v);
991 }
992 }
993 // Complete steps to finalize/cleanup the caller
994 LogicalResult finalizeResult =
995 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
996 if (failed(finalizeResult)) {
997 signalPassFailure(); // error already printed w/in finalizeStruct()
998 return;
999 }
1000 }
1001 }
1002};
1003
1004} // namespace
1005
1006std::unique_ptr<mlir::Pass> llzk::createInlineStructsPass() {
1007 return std::make_unique<InlineStructsPass>();
1008};