LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
1//===-- LLZKFlatteningPass.cpp - Implements -llzk-flatten pass --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
26#include "llzk/Util/Concepts.h"
27#include "llzk/Util/Debug.h"
30
31#include <mlir/Dialect/Affine/IR/AffineOps.h>
32#include <mlir/Dialect/Affine/LoopUtils.h>
33#include <mlir/Dialect/Arith/IR/Arith.h>
34#include <mlir/Dialect/SCF/IR/SCF.h>
35#include <mlir/Dialect/SCF/Utils/Utils.h>
36#include <mlir/Dialect/Utils/StaticValueUtils.h>
37#include <mlir/IR/Attributes.h>
38#include <mlir/IR/BuiltinAttributes.h>
39#include <mlir/IR/BuiltinOps.h>
40#include <mlir/IR/BuiltinTypes.h>
41#include <mlir/Interfaces/InferTypeOpInterface.h>
42#include <mlir/Pass/PassManager.h>
43#include <mlir/Support/LLVM.h>
44#include <mlir/Support/LogicalResult.h>
45#include <mlir/Transforms/DialectConversion.h>
46#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
47
48#include <llvm/ADT/APInt.h>
49#include <llvm/ADT/DenseMap.h>
50#include <llvm/ADT/DepthFirstIterator.h>
51#include <llvm/ADT/STLExtras.h>
52#include <llvm/ADT/SmallVector.h>
53#include <llvm/ADT/TypeSwitch.h>
54#include <llvm/Support/Debug.h>
55
56// Include the generated base pass class definitions.
57namespace llzk::polymorphic {
58#define GEN_PASS_DEF_FLATTENINGPASS
60} // namespace llzk::polymorphic
61
62#include "SharedImpl.h"
63
64#define DEBUG_TYPE "llzk-flatten"
65
66using namespace mlir;
67using namespace llzk;
68using namespace llzk::array;
69using namespace llzk::component;
70using namespace llzk::constrain;
71using namespace llzk::felt;
72using namespace llzk::function;
73using namespace llzk::polymorphic;
74using namespace llzk::polymorphic::detail;
75
76namespace {
77
78class ConversionTracker {
80 bool modified;
83 DenseMap<StructType, StructType> structInstantiations;
85 DenseMap<StructType, StructType> reverseInstantiations;
88 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
89
90public:
91 bool isModified() const { return modified; }
92 void resetModifiedFlag() { modified = false; }
93 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
95 void recordInstantiation(StructType oldType, StructType newType) {
96 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
98 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
99 if (forwardResult.second) {
100 // Insertion was successful
101 // ASSERT: The reverse map does not contain this mapping either
102 assert(!reverseInstantiations.contains(newType));
103 reverseInstantiations[newType] = oldType;
104 // Set the modified flag
105 modified = true;
106 } else {
107 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
108 assert(forwardResult.first->getSecond() == newType);
109 // ASSERT: The reverse mapping is already present as well
110 assert(reverseInstantiations.lookup(newType) == oldType);
112 assert(structInstantiations.size() == reverseInstantiations.size());
113 }
116 std::optional<StructType> getInstantiation(StructType oldType) const {
117 auto cachedResult = structInstantiations.find(oldType);
118 if (cachedResult != structInstantiations.end()) {
119 return cachedResult->second;
120 }
121 return std::nullopt;
122 }
123
124 /// Collect the fully-qualified names of all structs that were instantiated.
125 DenseSet<SymbolRefAttr> getInstantiatedStructNames() const {
126 DenseSet<SymbolRefAttr> instantiatedNames;
127 for (const auto &[origRemoteTy, _] : structInstantiations) {
128 instantiatedNames.insert(origRemoteTy.getNameRef());
129 }
130 return instantiatedNames;
131 }
132
133 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
134 auto res = delayedDiagnostics.find(newType);
135 if (res == delayedDiagnostics.end()) {
136 return;
137 }
138
139 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
140 for (Diagnostic &diag : res->second) {
141 // Update any notes referencing an UnknownLoc to use the CallOp location.
142 for (Diagnostic &note : diag.getNotes()) {
143 assert(note.getNotes().empty() && "notes cannot have notes attached");
144 if (llvm::isa<UnknownLoc>(note.getLocation())) {
145 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
146 }
147 }
148 // Report. Based on InFlightDiagnostic::report().
149 engine.emit(std::move(diag));
150 }
151 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
152 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
153 // `compute()` calls it will only be reported at one of those locations, not all.
154 delayedDiagnostics.erase(newType);
155 }
156
157 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
158 return delayedDiagnostics[newType];
159 }
160
163 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
164 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
165 // Check if `oTy` is a struct with a known instantiation to `nTy`
166 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
167 // Note: The values in `structInstantiations` must be no-parameter struct types
168 // so there is no need for recursive check, simple equality is sufficient.
169 if (this->structInstantiations.lookup(oldStructType) == nTy) {
170 return true;
171 }
172 }
173 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
174 // that instantiation (i.e. the parameterized version of the instantiated struct)
175 // is a more concrete unification of `oTy`.
176 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
177 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
178 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
179 return true;
180 }
181 }
182 }
183 return false;
184 };
185
186 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
187 return true;
188 }
189 LLVM_DEBUG(llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
190 << " with new type " << newType
191 << " because it does not define a compatible and more concrete type.\n";
192 );
193 return false;
194 }
195
196 template <typename T, typename U>
197 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
198 return llvm::all_of(
199 llvm::zip_equal(oldTypes, newTypes),
200 [this, &patName](std::tuple<Type, Type> oldThenNew) {
201 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
202 }
203 );
204 }
205};
206
209struct MatchFailureListener : public RewriterBase::Listener {
210 bool hadFailure = false;
211
212 ~MatchFailureListener() override {}
213
214 LogicalResult
215 notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
216 hadFailure = true;
217
218 InFlightDiagnostic diag = emitError(loc);
219 reasonCallback(*diag.getUnderlyingDiagnostic());
220 return diag; // implicitly calls `diag.report()`
221 }
222};
223
224static LogicalResult
225applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
226 bool currStepModified = false;
227 MatchFailureListener failureListener;
228 LogicalResult result = applyPatternsAndFoldGreedily(
229 modOp->getRegion(0), std::move(patterns),
230 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener}, &currStepModified
231 );
232 tracker.updateModifiedFlag(currStepModified);
233 return failure(result.failed() || failureListener.hadFailure);
234}
235
236template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
237 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
238 return isConcreteType(tyAttr.getValue(), AllowStructParams);
239 }
240 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
241 return !isDynamic(intAttr);
242 }
243 return false;
244}
245
247
248static inline bool tableOffsetIsntSymbol(FieldReadOp op) {
249 return !mlir::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
250}
251
254class StructCloner {
255 ConversionTracker &tracker_;
256 ModuleOp rootMod;
257 SymbolTableCollection symTables;
258
259 class MappedTypeConverter : public TypeConverter {
260 StructType origTy;
261 StructType newTy;
262 const DenseMap<Attribute, Attribute> &paramNameToValue;
263
264 inline Attribute convertIfPossible(Attribute a) const {
265 auto res = this->paramNameToValue.find(a);
266 return (res != this->paramNameToValue.end()) ? res->second : a;
267 }
268
269 public:
270 MappedTypeConverter(
271 StructType originalType, StructType newType,
273 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
274 )
275 : TypeConverter(), origTy(originalType), newTy(newType),
276 paramNameToValue(paramNameToInstantiatedValue) {
277
278 addConversion([](Type inputTy) { return inputTy; });
279
280 addConversion([this](StructType inputTy) {
281 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
282
283 // Check for replacement of the full type
284 if (inputTy == this->origTy) {
285 return this->newTy;
286 }
287 // Check for replacement of parameter symbol names with concrete values
288 if (ArrayAttr inputTyParams = inputTy.getParams()) {
289 SmallVector<Attribute> updated;
290 for (Attribute a : inputTyParams) {
291 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
292 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
293 } else {
294 updated.push_back(convertIfPossible(a));
295 }
296 }
297 return StructType::get(
298 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
299 );
300 }
301 // Otherwise, return the type unchanged
302 return inputTy;
303 });
304
305 addConversion([this](ArrayType inputTy) {
306 // Check for replacement of parameter symbol names with concrete values
307 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
308 if (!dimSizes.empty()) {
309 SmallVector<Attribute> updated;
310 for (Attribute a : dimSizes) {
311 updated.push_back(convertIfPossible(a));
312 }
313 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
314 }
315 // Otherwise, return the type unchanged
316 return inputTy;
317 });
318
319 addConversion([this](TypeVarType inputTy) -> Type {
320 // Check for replacement of parameter symbol name with a concrete type
321 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
322 Type convertedType = tyAttr.getValue();
323 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
324 // different struct references a parameter name from that other struct, not from the
325 // current struct so the reference would be invalid.
326 if (isConcreteType(convertedType)) {
327 return convertedType;
328 }
329 }
330 return inputTy;
331 });
332 }
333 };
334
335 template <typename Impl, typename Op, typename... HandledAttrs>
336 class SymbolUserHelper : public OpConversionPattern<Op> {
337 private:
338 const DenseMap<Attribute, Attribute> &paramNameToValue;
339
340 SymbolUserHelper(
341 TypeConverter &converter, MLIRContext *ctx, unsigned Benefit,
342 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
343 )
344 : OpConversionPattern<Op>(converter, ctx, Benefit),
345 paramNameToValue(paramNameToInstantiatedValue) {}
346
347 public:
348 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
349
350 virtual Attribute getNameAttr(Op) const = 0;
351
352 virtual LogicalResult handleDefaultRewrite(
353 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
354 ) const {
355 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
356 }
357
358 LogicalResult
359 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
360 auto res = this->paramNameToValue.find(getNameAttr(op));
361 if (res == this->paramNameToValue.end()) {
362 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] no instantiation for " << op << '\n');
363 return failure();
364 }
365 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
366 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
367
368 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
369 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
370 }))),
371 ...);
372
373 return TS.Default([&](Attribute a) {
374 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
375 });
376 }
377 friend Impl;
378 };
379
// Conversion pattern for ConstReadOp inside a cloned struct: when the constant name read by the op
// has a mapped concrete value, the op is replaced by a constant op matching its result type.
// Warnings produced along the way are appended to `diagnostics` instead of being emitted here.
380 class ClonedStructConstReadOpPattern
381 : public SymbolUserHelper<
382 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
// Caller-owned list that receives the warnings created in handleRewrite().
383 SmallVector<Diagnostic> &diagnostics;
384
385 using super =
386 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
387
388 public:
389 ClonedStructConstReadOpPattern(
390 TypeConverter &converter, MLIRContext *ctx,
391 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
392 SmallVector<Diagnostic> &instantiationDiagnostics
393 )
394 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
395 // instead of the GeneralTypeReplacePattern<ConstReadOp> from newGeneralRewritePatternSet().
396 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue),
397 diagnostics(instantiationDiagnostics) {}
398
// The lookup key is the name of the constant that the op reads.
399 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
400
// Handles an integer value; the replacement depends on the result type of the op
// (felt, index, or 1-bit signless integer). Any other result type is reported as an op error.
401 LogicalResult handleRewrite(
402 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
403 ) const {
404 APInt attrValue = a.getValue();
405 Type origResTy = op.getType();
406 if (llvm::isa<FeltType>(origResTy)) {
// NOTE(review): this listing skips original line 407 here; the start of the call whose arguments
// follow is not visible - confirm against the real file.
408 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
409 );
410 return success();
411 }
412
413 if (llvm::isa<IndexType>(origResTy)) {
// NOTE(review): this listing skips original line 414 here; presumably the op is replaced by an
// index constant before returning - confirm against the real file.
415 return success();
416 }
417
418 if (origResTy.isSignlessInteger(1)) {
419 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
420 if (attrValue.isZero()) {
421 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
422 return success();
423 }
424 if (!attrValue.isOne()) {
425 Location opLoc = op.getLoc();
426 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
427 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
428 if (getContext()->shouldPrintOpOnDiagnostic()) {
429 diag.attachNote(opLoc) << "see current operation: " << *op;
430 }
// This note is attached with an UnknownLoc; the warning is stored rather than emitted.
431 diag.attachNote(UnknownLoc::get(getContext()))
432 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \""
433 << sym << "\" for this call";
434 diagnostics.push_back(std::move(diag));
435 }
436 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
437 return success();
438 }
439 return op->emitOpError().append("unexpected result type ", origResTy);
440 }
441
// Handles a felt constant value: the op becomes a FeltConstantOp holding that attribute.
442 LogicalResult handleRewrite(
443 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
444 ) const {
445 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
446 return success();
447 }
448 };
449
450 class ClonedStructFieldReadOpPattern
451 : public SymbolUserHelper<
452 ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr> {
453 using super =
454 SymbolUserHelper<ClonedStructFieldReadOpPattern, FieldReadOp, IntegerAttr, FeltConstAttr>;
455
456 public:
457 ClonedStructFieldReadOpPattern(
458 TypeConverter &converter, MLIRContext *ctx,
459 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
460 )
461 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
462 // instead of the GeneralTypeReplacePattern<FieldReadOp> from newGeneralRewritePatternSet().
463 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue) {}
464
465 Attribute getNameAttr(FieldReadOp op) const override {
466 return op.getTableOffset().value_or(nullptr);
467 }
468
469 template <typename Attr>
470 LogicalResult handleRewrite(
471 Attribute, FieldReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
472 ) const {
473 rewriter.modifyOpInPlace(op, [&]() {
474 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
475 });
476
477 return success();
478 }
479
480 LogicalResult matchAndRewrite(
481 FieldReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
482 ) const override {
483 if (tableOffsetIsntSymbol(op)) {
484 return failure();
485 }
486
487 return super::matchAndRewrite(op, adaptor, rewriter);
488 }
489 };
490
// Clone the definition of the struct named by `typeAtCaller` and substitute into the clone every
// parameter for which `typeAtCallerParams` holds a concrete value. Returns the struct type to use
// at the caller (it keeps only the non-concrete parameters). Fails when the definition cannot be
// found, when no parameter is concrete, or when converting the body of the clone fails.
491 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
492 // Find the StructDefOp for the original StructType
493 FailureOr<SymbolLookupResult<StructDefOp>> r = typeAtCaller.getDefinition(symTables, rootMod);
494 if (failed(r)) {
495 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
496 return failure(); // getDefinition() already emits a sufficient error message
497 }
498
499 StructDefOp origStruct = r->get();
500 StructType typeAtDef = origStruct.getType();
501 MLIRContext *ctx = origStruct.getContext();
502
503 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
504 DenseMap<Attribute, Attribute> paramNameToConcrete;
505 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
506 // where the original attribute from the current instantiation site was not concrete. This is
507 // used for generating the new struct name. See `BuildShortTypeString::from()`.
508 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
509 // Parameter list for the new StructDefOp containing the names that must be preserved because
510 // they were not assigned concrete values at the current instantiation site.
511 ArrayAttr reducedParamNameList = nullptr;
512 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
513 ArrayAttr reducedCallerParams = nullptr;
514 {
515 ArrayAttr paramNames = typeAtDef.getParams();
516
517 // pre-conditions
518 assert(!isNullOrEmpty(paramNames));
519 assert(paramNames.size() == typeAtCallerParams.size());
520
521 SmallVector<Attribute> remainingNames;
522 SmallVector<Attribute> nonConcreteParams;
// Split the parameters position by position into concrete ones (recorded in the map) and
// non-concrete ones (kept as parameters of the clone).
523 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
524 Attribute next = typeAtCallerParams[i];
525 if (isConcreteAttr<false>(next)) {
526 paramNameToConcrete[paramNames[i]] = next;
527 attrsForInstantiatedNameSuffix.push_back(next);
528 } else {
529 remainingNames.push_back(paramNames[i]);
530 nonConcreteParams.push_back(next);
531 attrsForInstantiatedNameSuffix.push_back(nullptr);
532 }
533 }
534 // post-conditions
535 assert(remainingNames.size() == nonConcreteParams.size());
536 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
537 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
538
// Nothing to instantiate if no parameter has a concrete value.
539 if (paramNameToConcrete.empty()) {
540 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
541 return failure();
542 }
// When every parameter is concrete both reduced lists stay null.
543 if (!remainingNames.empty()) {
544 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
545 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
546 }
547 }
548
549 // Clone the original struct, apply the new name, and set the parameter list of the new struct
550 // to contain only those that did not have concrete instantiated values.
551 StructDefOp newStruct = origStruct.clone();
552 newStruct.setConstParamsAttr(reducedParamNameList);
// NOTE(review): this listing skips original line 553 here; the start of the statement that the
// next two lines complete is not visible - presumably it sets the name of `newStruct`; confirm
// against the real file.
554 typeAtCaller.getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
555 ));
556
557 // Insert 'newStruct' into the parent ModuleOp of the original StructDefOp. Use the
558 // `SymbolTable::insert()` function directly so that the name will be made unique.
559 ModuleOp parentModule = llvm::cast<ModuleOp>(origStruct.getParentOp());
560 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
561 // Retrieve the new type AFTER inserting since the name may be appended to make it unique and
562 // use the remaining non-concrete parameters from the original type.
563 StructType newRemoteType = newStruct.getType(reducedCallerParams);
564 LLVM_DEBUG({
565 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
566 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
567 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
568 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
569 });
570
571 // Within the new struct, replace all references to the original StructType (i.e. the
572 // locally-parameterized version) with the new locally-parameterized StructType,
573 // and replace all uses of the removed struct parameters with the concrete values.
574 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
575 ConversionTarget target =
576 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
577 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
578 // Legal if it's not in the map of concrete attribute instantiations
579 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
580 });
581
582 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
// Warnings from the ConstReadOp pattern go into the tracker's delayed list for `newRemoteType`.
583 patterns.add<ClonedStructConstReadOpPattern>(
584 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newRemoteType)
585 );
586 patterns.add<ClonedStructFieldReadOpPattern>(tyConv, ctx, paramNameToConcrete);
587 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
588 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
589 return failure();
590 }
591 return newRemoteType;
592 }
593
594public:
// Keeps a reference to `tracker` and a handle to the root module; the symbol table collection is
// default-constructed.
595 StructCloner(ConversionTracker &tracker, ModuleOp root)
596 : tracker_(tracker), rootMod(root), symTables() {}
597
598 FailureOr<StructType> createInstantiatedClone(StructType orig) {
599 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
600 if (ArrayAttr params = orig.getParams()) {
601 return genClone(orig, params.getValue());
602 }
603 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
604 return failure();
605 }
606};
607
608class ParameterizedStructUseTypeConverter : public TypeConverter {
609 ConversionTracker &tracker_;
610 StructCloner cloner;
611
612public:
613 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
614 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
615
616 addConversion([](Type inputTy) { return inputTy; });
617
618 addConversion([this](StructType inputTy) -> StructType {
619 // First check for a cached entry
620 if (auto opt = tracker_.getInstantiation(inputTy)) {
621 return opt.value();
622 }
623
624 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
625 // done, return the original type to indicate that it's still legal (for this step at least).
626 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
627 if (failed(cloneRes)) {
628 return inputTy;
629 }
630 StructType newTy = cloneRes.value();
631 LLVM_DEBUG(
632 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
633 << " as " << newTy << '\n'
634 );
635 tracker_.recordInstantiation(inputTy, newTy);
636 return newTy;
637 });
638
639 addConversion([this](ArrayType inputTy) {
640 return inputTy.cloneWith(convertType(inputTy.getElementType()));
641 });
642 }
643};
644
// Conversion pattern for CallOp: converts the result types and, when the callee is a struct
// `compute` or `constrain` function, points the callee symbol at the converted struct type.
645class CallStructFuncPattern : public OpConversionPattern<CallOp> {
// Used to report the diagnostics that were delayed while instantiating the callee's struct.
646 ConversionTracker &tracker_;
647
648public:
649 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
650 // Must use higher benefit than CallOpClassReplacePattern so this pattern will be applied
651 // instead of the CallOpClassReplacePattern from newGeneralRewritePatternSet().
652 : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/2), tracker_(tracker) {}
653
654 LogicalResult matchAndRewrite(CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter)
655 const override {
656 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');
657
658 // Convert the result types of the CallOp
659 SmallVector<Type> newResultTypes;
660 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
661 return op->emitError("Could not convert Op result types.");
662 }
663 LLVM_DEBUG({
664 llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
665 << debug::toStringList(newResultTypes) << '\n';
666 });
667
668 // Update the callee to reflect the new struct target if necessary. These checks are based on
669 // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
670 // Instead they must come from the converted versions.
671 SymbolRefAttr calleeAttr = op.getCalleeAttr();
672 if (op.calleeIsStructCompute()) {
// For `compute`, the struct type is taken from the single converted result type.
673 if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
674 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
675 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
676 tracker_.reportDelayedDiagnostics(newStTy, op);
677 }
678 } else if (op.calleeIsStructConstrain()) {
// For `constrain`, the struct type is taken from the first converted argument operand.
679 if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
680 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
681 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
682 }
683 }
684
685 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
// NOTE(review): this listing skips original line 686 here; the statement that declares `newOp`
// and starts the call whose arguments follow is not visible - confirm against the real file.
687 rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
688 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
689 );
690 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
691 return success();
692 }
693};
694
695// This one ensures FieldDefOp types are converted even if there are no reads/writes to them.
696class FieldDefOpPattern : public OpConversionPattern<FieldDefOp> {
697public:
698 FieldDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
699 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
700 // instead of the GeneralTypeReplacePattern<FieldDefOp> from newGeneralRewritePatternSet().
701 : OpConversionPattern<FieldDefOp>(converter, ctx, /*benefit=*/2) {}
702
703 LogicalResult matchAndRewrite(
704 FieldDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
705 ) const override {
706 LLVM_DEBUG(llvm::dbgs() << "[FieldDefOpPattern] FieldDefOp: " << op << '\n');
707
708 Type oldFieldType = op.getType();
709 Type newFieldType = getTypeConverter()->convertType(oldFieldType);
710 if (oldFieldType == newFieldType) {
711 // nothing changed
712 return failure();
713 }
714 rewriter.modifyOpInPlace(op, [&op, &newFieldType]() { op.setType(newFieldType); });
715 return success();
716 }
717};
718
719LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
720 MLIRContext *ctx = modOp.getContext();
721 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
722 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
723 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
724 patterns.add<CallStructFuncPattern, FieldDefOpPattern>(tyConv, ctx, tracker);
725 return applyPartialConversion(modOp, target, std::move(patterns));
726}
727
728} // namespace Step1_InstantiateStructs
729
730namespace Step2_Unroll {
731
732// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
733template <HasInterface<LoopLikeOpInterface> OpClass>
734class LoopUnrollPattern : public OpRewritePattern<OpClass> {
735public:
736 using OpRewritePattern<OpClass>::OpRewritePattern;
737
738 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
739 if (auto maybeConstant = getConstantTripCount(loopOp)) {
740 uint64_t tripCount = *maybeConstant;
741 if (tripCount == 0) {
742 rewriter.eraseOp(loopOp);
743 return success();
744 } else if (tripCount == 1) {
745 return loopOp.promoteIfSingleIteration(rewriter);
746 }
747 return loopUnrollByFactor(loopOp, tripCount);
748 }
749 return failure();
750 }
751
752private:
755 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
756 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
757 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
758 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
759 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
760 return std::nullopt;
761 }
762 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
763 }
764};
765
766LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
767 MLIRContext *ctx = modOp.getContext();
768 RewritePatternSet patterns(ctx);
769 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
770 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
771
772 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
773}
774} // namespace Step2_Unroll
775
777
778// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
779// version uses a basic loop instead of llvm::map_to_vector().
780std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
781 SmallVector<int64_t> res;
782 for (OpFoldResult ofr : ofrs) {
783 std::optional<int64_t> cv = getConstantIntValue(ofr);
784 if (!cv.has_value()) {
785 return std::nullopt;
786 }
787 res.push_back(cv.value());
788 }
789 return res;
790}
791
792struct AffineMapFolder {
793 struct Input {
794 OperandRangeRange mapOpGroups;
795 DenseI32ArrayAttr dimsPerGroup;
796 ArrayRef<Attribute> paramsOfStructTy;
797 };
798
799 struct Output {
800 SmallVector<SmallVector<Value>> mapOpGroups;
801 SmallVector<int32_t> dimsPerGroup;
802 SmallVector<Attribute> paramsOfStructTy;
803 };
804
805 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
806 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
807 return ValueRange(grp);
808 });
809 }
810
811 static LogicalResult
812 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
813 if (in.mapOpGroups.empty()) {
814 // No affine map operands so nothing to do
815 return failure();
816 }
817
818 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
819 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
820
821 size_t idx = 0; // index in `mapOpGroups`, i.e. the number of AffineMapAttr encountered
822 for (Attribute sizeAttr : in.paramsOfStructTy) {
823 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
824 ValueRange currMapOps = in.mapOpGroups[idx++];
825 LLVM_DEBUG(
826 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
827 << '\n'
828 );
829 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
830 LLVM_DEBUG(
831 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
832 << debug::toStringList(currMapOpsCast) << '\n'
833 );
834 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
835 SmallVector<Attribute> result;
836 bool hasPoison = false; // indicates divide by 0 or mod by <1
837 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
838 return rewriter.getIndexAttr(v);
839 });
840 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
841 if (hasPoison) {
842 LLVM_DEBUG(op->emitRemark().append(
843 "Cannot fold affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
844 " due to divide by 0 or modulus with negative divisor"
845 ));
846 return failure();
847 }
848 if (failed(foldResult)) {
849 LLVM_DEBUG(op->emitRemark().append(
850 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " failed"
851 ));
852 return failure();
853 }
854 if (result.size() != 1) {
855 LLVM_DEBUG(op->emitRemark().append(
856 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " produced ",
857 result.size(), " results but expected 1"
858 ));
859 return failure();
860 }
861 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
862 out.paramsOfStructTy.push_back(result[0]);
863 continue;
864 }
865 // If affine but not foldable, preserve the map ops
866 out.mapOpGroups.emplace_back(currMapOps);
867 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
868 }
869 // If not affine and foldable, preserve the original
870 out.paramsOfStructTy.push_back(sizeAttr);
871 }
872 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
873 assert(
874 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
875 "produced wrong number of dimensions"
876 );
877
878 return success();
879 }
880};
881
883class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
884 ConversionTracker &tracker_;
885
886public:
887 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
888 : OpRewritePattern(ctx), tracker_(tracker) {}
889
890 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
891 ArrayType oldResultType = op.getType();
892
893 AffineMapFolder::Output out;
894 AffineMapFolder::Input in = {
895 op.getMapOperands(),
897 oldResultType.getDimensionSizes(),
898 };
899 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
900 return failure();
901 }
902
903 ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
904 if (newResultType == oldResultType) {
905 // nothing changed
906 return failure();
907 }
908 // ASSERT: folding only preserves the original Attribute or converts affine to integer
909 assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
910 LLVM_DEBUG(
911 llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
912 << newResultType << " in \"" << op << "\"\n"
913 );
915 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
916 );
917 return success();
918 }
919};
920
922class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
923 ConversionTracker &tracker_;
924
925public:
926 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
927 : OpRewritePattern(ctx), tracker_(tracker) {}
928
929 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
930 if (!op.calleeIsStructCompute()) {
931 // this pattern only applies when the callee is "compute()" within a struct
932 return failure();
933 }
934 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
936 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
937 ArrayAttr params = oldRetTy.getParams();
938 if (isNullOrEmpty(params)) {
939 // nothing to do if the StructType is not parameterized
940 return failure();
941 }
942
943 AffineMapFolder::Output out;
944 AffineMapFolder::Input in = {
945 op.getMapOperands(),
947 params.getValue(),
948 };
949 if (!in.mapOpGroups.empty()) {
950 // If there are affine map operands, attempt to fold them to a constant.
951 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
952 return failure();
953 }
954 LLVM_DEBUG({
955 llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
956 });
957 } else {
958 // If there are no affine map operands, attempt to refine the result type of the CallOp using
959 // the function argument types and the type of the target function.
960 auto callArgTypes = op.getArgOperands().getTypes();
961 if (callArgTypes.empty()) {
962 // no refinement possible if no function arguments
963 return failure();
964 }
965 SymbolTableCollection tables;
966 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
967 if (failed(lookupRes)) {
968 return failure();
969 }
970 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
971 return failure();
972 }
973 LLVM_DEBUG({
974 llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
975 "result type params: "
976 << debug::toStringList(out.paramsOfStructTy) << '\n';
977 });
978 }
979
980 StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
981 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
982 if (newRetTy == oldRetTy) {
983 // nothing changed
984 return failure();
985 }
986 // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
987 // original Attribute or convert to a concrete attribute via the unification process. Thus, if
988 // the conversion here is illegal it means there is a type conflict within the LLZK code that
989 // prevents instantiation of the struct with the requested type.
990 if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
991 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
992 diag.append(
993 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
994 ", but found ", oldRetTy
995 );
996 });
997 }
998 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
1000 rewriter, op, TypeRange {newRetTy}, op.getCallee(),
1001 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
1002 );
1003 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1004 return success();
1005 }
1006
1007private:
1010 inline LogicalResult instantiateViaTargetType(
1011 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1012 OperandRange::type_range callArgTypes, FuncDefOp targetFunc
1013 ) const {
1014 assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
1015 ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
1016 assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
1017 assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this
1018
1019 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1020 // Nothing can change if everything is already concrete
1021 return failure();
1022 }
1023
1024 LLVM_DEBUG({
1025 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1026 << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
1027 llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
1028 << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
1029 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1030 << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
1031 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1032 << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
1033 });
1034
1035 UnificationMap unifications;
1036 bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
1037 assert(unifies && "should have been checked by verifiers");
1038
1039 LLVM_DEBUG({
1040 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1041 << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
1042 });
1043
1044 // Check for LHS SymRef (i.e. from the target function) that have RHS concrete Attributes (i.e.
1045 // from the call argument types) without any struct parameters (because the type with concrete
1046 // struct parameters will be used to instantiate the target struct rather than the fully
1047 // flattened struct type resulting in type mismatch of the callee to target) and perform those
1048 // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
1049 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1050 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1051 [&unifications](std::tuple<Attribute, Attribute> p) {
1052 Attribute fromCall = std::get<1>(p);
1053 // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
1054 // non-parameterized concrete unification for the target struct parameter symbol.
1055 if (!isConcreteAttr<>(fromCall)) {
1056 Attribute fromTgt = std::get<0>(p);
1057 LLVM_DEBUG({
1058 llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
1059 llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
1060 });
1061 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1062 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1063 if (it != unifications.end()) {
1064 Attribute unifiedAttr = it->second;
1065 LLVM_DEBUG({
1066 llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
1067 });
1068 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1069 return unifiedAttr;
1070 }
1071 }
1072 }
1073 return fromCall;
1074 }
1075 );
1076
1077 out.paramsOfStructTy = newReturnStructParams;
1078 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
1079 assert(out.mapOpGroups.empty() && "post-condition");
1080 assert(out.dimsPerGroup.empty() && "post-condition");
1081 return success();
1082 }
1083};
1084
1085LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1086 MLIRContext *ctx = modOp.getContext();
1087 RewritePatternSet patterns(ctx);
1088 patterns.add<
1089 InstantiateAtCreateArrayOp, // CreateArrayOp
1090 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1091 >(ctx, tracker);
1092
1093 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1094}
1095
1096} // namespace Step3_InstantiateAffineMaps
1097
1099
1101class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1102 ConversionTracker &tracker_;
1103
1104public:
1105 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1106 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1107
1108 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1109 Value createResult = op.getResult();
1110 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1111 assert(createResultType && "CreateArrayOp must produce ArrayType");
1112 Type oldResultElemType = createResultType.getElementType();
1113
1114 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1115 // element type is different.
1116 Type newResultElemType = nullptr;
1117 for (Operation *user : createResult.getUsers()) {
1118 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1119 if (writeOp.getArrRef() != createResult) {
1120 continue;
1121 }
1122 Type writeRValueType = writeOp.getRvalue().getType();
1123 if (writeRValueType == oldResultElemType) {
1124 continue;
1125 }
1126 if (newResultElemType && newResultElemType != writeRValueType) {
1127 LLVM_DEBUG(
1128 llvm::dbgs()
1129 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1130 << newResultElemType << " vs " << writeRValueType << '\n'
1131 );
1132 return failure();
1133 }
1134 newResultElemType = writeRValueType;
1135 }
1136 }
1137 if (!newResultElemType) {
1138 // no replacement type found
1139 return failure();
1140 }
1141 if (!tracker_.isLegalConversion(
1142 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1143 )) {
1144 return failure();
1145 }
1146 ArrayType newType = createResultType.cloneWith(newResultElemType);
1147 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1148 LLVM_DEBUG(
1149 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1150 );
1151 return success();
1152 }
1153};
1154
1155namespace {
1156
1157LogicalResult updateArrayElemFromArrAccessOp(
1158 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1159 PatternRewriter &rewriter
1160) {
1161 ArrayType oldArrType = op.getArrRefType();
1162 if (oldArrType.getElementType() == scalarElemTy) {
1163 return failure(); // no change needed
1164 }
1165 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1166 if (oldArrType == newArrType ||
1167 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1168 return failure();
1169 }
1170 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1171 LLVM_DEBUG(
1172 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1173 );
1174 return success();
1175}
1176
1177} // namespace
1178
1179class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1180 ConversionTracker &tracker_;
1181
1182public:
1183 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1184 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1185
1186 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1187 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1188 }
1189};
1190
1191class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1192 ConversionTracker &tracker_;
1193
1194public:
1195 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1196 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1197
1198 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1199 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1200 }
1201};
1202
1204class UpdateFieldDefTypeFromWrite final : public OpRewritePattern<FieldDefOp> {
1205 ConversionTracker &tracker_;
1206
1207public:
1208 UpdateFieldDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1209 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1210
1211 LogicalResult matchAndRewrite(FieldDefOp op, PatternRewriter &rewriter) const override {
1212 // Find all uses of the field symbol name within its parent struct.
1213 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(op);
1214 assert(succeeded(parentRes) && "FieldDefOp parent is always StructDefOp"); // per ODS def
1215
1216 // If the symbol is used by a FieldWriteOp with a different result type then change
1217 // the type of the FieldDefOp to match the FieldWriteOp result type.
1218 Type newType = nullptr;
1219 if (auto fieldUsers = SymbolTable::getSymbolUses(op, parentRes.value())) {
1220 std::optional<Location> newTypeLoc = std::nullopt;
1221 for (SymbolTable::SymbolUse symUse : fieldUsers.value()) {
1222 if (FieldWriteOp writeOp = llvm::dyn_cast<FieldWriteOp>(symUse.getUser())) {
1223 Type writeToType = writeOp.getVal().getType();
1224 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] checking " << writeOp << '\n');
1225 if (!newType) {
1226 // If a new type has not yet been discovered, store the new type.
1227 newType = writeToType;
1228 newTypeLoc = writeOp.getLoc();
1229 } else if (writeToType != newType) {
1230 // Typically, there will only be one write for each field of a struct but do not rely on
1231 // that assumption. If multiple writes with a different types A and B are found where
1232 // A->B is a legal conversion (i.e. more concrete unification), then it is safe to use
1233 // type B with the assumption that the write with type A will be updated by another
1234 // pattern to also use type B.
1235 if (!tracker_.isLegalConversion(writeToType, newType, "UpdateFieldDefTypeFromWrite")) {
1236 if (tracker_.isLegalConversion(newType, writeToType, "UpdateFieldDefTypeFromWrite")) {
1237 // 'writeToType' is the more concrete type
1238 newType = writeToType;
1239 newTypeLoc = writeOp.getLoc();
1240 } else {
1241 // Give an error if the types are incompatible.
1242 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1243 diag.append(
1244 "Cannot update type of '", FieldDefOp::getOperationName(),
1245 "' because there are multiple '", FieldWriteOp::getOperationName(),
1246 "' with different value types"
1247 );
1248 if (newTypeLoc) {
1249 diag.attachNote(*newTypeLoc).append("type written here is ", newType);
1250 }
1251 diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
1252 });
1253 }
1254 }
1255 }
1256 }
1257 }
1258 }
1259 if (!newType || newType == op.getType()) {
1260 // nothing changed
1261 return failure();
1262 }
1263 if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateFieldDefTypeFromWrite")) {
1264 return failure();
1265 }
1266 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
1267 LLVM_DEBUG(llvm::dbgs() << "[UpdateFieldDefTypeFromWrite] updated type of " << op << '\n');
1268 return success();
1269 }
1270};
1271
1272namespace {
1273
1274SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1275 SmallVector<std::unique_ptr<Region>> newRegions;
1276 for (Region &region : op->getRegions()) {
1277 auto newRegion = std::make_unique<Region>();
1278 newRegion->takeBody(region);
1279 newRegions.push_back(std::move(newRegion));
1280 }
1281 return newRegions;
1282}
1283
1284} // namespace
1285
1288class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1289 ConversionTracker &tracker_;
1290
1291public:
1292 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1293 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1294
1295 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1296 SmallVector<Type, 1> inferredResultTypes;
1297 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1298 LogicalResult result = retTypeFn.inferReturnTypes(
1299 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1300 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1301 );
1302 if (failed(result)) {
1303 return failure();
1304 }
1305 if (op->getResultTypes() == inferredResultTypes) {
1306 // nothing changed
1307 return failure();
1308 }
1309 if (!tracker_.areLegalConversions(
1310 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1311 )) {
1312 return failure();
1313 }
1314
1315 // Move nested region bodies and replace the original op with the updated types list.
1316 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1317 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1318 Operation *newOp = rewriter.create(
1319 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1320 op->getAttrs(), op->getSuccessors(), newRegions
1321 );
1322 rewriter.replaceOp(op, newOp);
1323 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1324 return success();
1325 }
1326};
1327
1329class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1330 ConversionTracker &tracker_;
1331
1332public:
1333 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1334 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1335
1336 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1337 Region &body = op.getFunctionBody();
1338 if (body.empty()) {
1339 return failure();
1340 }
1341 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1342 assert(retOp && "final op in body region must be return");
1343 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1344
1345 FunctionType oldFuncTy = op.getFunctionType();
1346 if (oldFuncTy.getResults() == tyFromReturnOp) {
1347 // nothing changed
1348 return failure();
1349 }
1350 if (!tracker_.areLegalConversions(
1351 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1352 )) {
1353 return failure();
1354 }
1355
1356 rewriter.modifyOpInPlace(op, [&]() {
1357 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1358 });
1359 LLVM_DEBUG(
1360 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
1361 << oldFuncTy << " to " << op.getFunctionType() << '\n'
1362 );
1363 return success();
1364 }
1365};
1366
1371class UpdateGlobalCallOpTypes final : public OpRewritePattern<CallOp> {
1372 ConversionTracker &tracker_;
1373
1374public:
1375 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1376 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1377
1378 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1379 SymbolTableCollection tables;
1380 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1381 if (failed(lookupRes)) {
1382 return failure();
1383 }
1384 FuncDefOp targetFunc = lookupRes->get();
1385 if (targetFunc.isInStruct()) {
1386 // this pattern only applies when the callee is NOT in a struct
1387 return failure();
1388 }
1389 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
1390 // nothing changed
1391 return failure();
1392 }
1393 if (!tracker_.areLegalConversions(
1394 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
1395 "UpdateGlobalCallOpTypes"
1396 )) {
1397 return failure();
1398 }
1399
1400 LLVM_DEBUG(llvm::dbgs() << "[UpdateGlobalCallOpTypes] replaced " << op);
1401 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
1402 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1403 return success();
1404 }
1405};
1406
1407namespace {
1408
1409LogicalResult updateFieldRefValFromFieldDef(
1410 FieldRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
1411) {
1412 SymbolTableCollection tables;
1413 auto def = op.getFieldDefOp(tables);
1414 if (failed(def)) {
1415 return failure();
1416 }
1417 Type oldResultType = op.getVal().getType();
1418 Type newResultType = def->get().getType();
1419 if (oldResultType == newResultType ||
1420 !tracker.isLegalConversion(oldResultType, newResultType, "updateFieldRefValFromFieldDef")) {
1421 return failure();
1422 }
1423 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
1424 LLVM_DEBUG(
1425 llvm::dbgs() << "[updateFieldRefValFromFieldDef] updated value type in " << op << '\n'
1426 );
1427 return success();
1428}
1429
1430} // namespace
1431
1433class UpdateFieldReadValFromDef final : public OpRewritePattern<FieldReadOp> {
1434 ConversionTracker &tracker_;
1435
1436public:
1437 UpdateFieldReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1438 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1439
1440 LogicalResult matchAndRewrite(FieldReadOp op, PatternRewriter &rewriter) const override {
1441 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1442 }
1443};
1444
1446class UpdateFieldWriteValFromDef final : public OpRewritePattern<FieldWriteOp> {
1447 ConversionTracker &tracker_;
1448
1449public:
1450 UpdateFieldWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1451 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1452
1453 LogicalResult matchAndRewrite(FieldWriteOp op, PatternRewriter &rewriter) const override {
1454 return updateFieldRefValFromFieldDef(op, tracker_, rewriter);
1455 }
1456};
1457
1458LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1459 MLIRContext *ctx = modOp.getContext();
1460 RewritePatternSet patterns(ctx);
1461 patterns.add<
1462 // Benefit of this one must be higher than rules that would propagate the type in the opposite
1463 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
1464 // benefit = 6
1465 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
1466 // benefit = 3
1467 UpdateGlobalCallOpTypes, // CallOp, targeting non-struct functions
1468 UpdateFuncTypeFromReturn, // FuncDefOp
1469 UpdateNewArrayElemFromWrite, // CreateArrayOp
1470 UpdateArrayElemFromArrRead, // ReadArrayOp
1471 UpdateArrayElemFromArrWrite, // WriteArrayOp
1472 UpdateFieldDefTypeFromWrite, // FieldDefOp
1473 UpdateFieldReadValFromDef, // FieldReadOp
1474 UpdateFieldWriteValFromDef // FieldWriteOp
1475 >(ctx, tracker);
1476
1477 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1478}
1479} // namespace Step4_PropagateTypes
1480
1481namespace Step5_Cleanup {
1482
1483template <typename OpClass> class EraseOpPattern : public OpConversionPattern<OpClass> {
1484public:
1485 EraseOpPattern(MLIRContext *ctx) : OpConversionPattern<OpClass>(ctx) {}
1486
1487 LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor, ConversionPatternRewriter &rewriter)
1488 const override {
1489 rewriter.eraseOp(op);
1490 return success();
1491 }
1492};
1493
1494LogicalResult run(ModuleOp modOp, const ConversionTracker &tracker) {
1495 FailureOr<ModuleOp> topRoot = getTopRootModule(modOp);
1496 if (failed(topRoot)) {
1497 return failure();
1498 }
1499
1500 // Use a conversion to erase instantiated structs if they have no other references.
1501 //
1502 // TODO: There's a chance the "no other references" criteria will leave some behind when running
1503 // only a single pass of this because they may reference each other. Maybe I can check if the
1504 // references are only located within another struct in the list, but would have to do a deep
1505 // deep lookup to ensure no references and avoid infinite loop back on self. The CallGraphAnalysis
1506 // is not sufficient because it looks only at calls but there could be (although unlikely) a
1507 // FieldDefOp referencing a struct type despite having no calls to that struct's functions.
1508 //
1509 // TODO: There's another scenario that leaves some behind. Once a StructDefOp is visited and
1510 // considered legal, that decision cannot be reversed. Hence, StructDefOp that become illegal only
1511 // after removing another one that uses it will not be removed. See
1512 // test/Dialect/LLZK/instantiate_structs_affine_pass.llzk
1513 // One idea is to use one of the `SymbolTable::getSymbolUses` functions starting from a struct
1514 // listed in `instantiatedNames` to determine if it is reachable from some other struct that is
1515 // NOT listed there and remove it if not. For efficiency, this reachability information can be
1516 // pre-computed and or cached.
1517 //
1518 DenseSet<SymbolRefAttr> instantiatedNames = tracker.getInstantiatedStructNames();
1519 auto isLegalStruct = [&](bool emitWarning, StructDefOp op) {
1520 if (instantiatedNames.contains(op.getType().getNameRef())) {
1521 if (!hasUsesWithin(op, *topRoot)) {
1522 // Parameterized struct with no uses is illegal, i.e. should be removed.
1523 return false;
1524 }
1525 if (emitWarning) {
1526 op.emitWarning("Parameterized struct still has uses!").report();
1527 }
1528 }
1529 return true;
1530 };
1531
1532 // Perform the conversion, i.e. remove StructDefOp that were instantiated and are unused.
1533 MLIRContext *ctx = modOp.getContext();
1534 RewritePatternSet patterns(ctx);
1535 patterns.add<EraseOpPattern<StructDefOp>>(ctx);
1536 ConversionTarget target = newBaseTarget(ctx);
1537 // TODO: Here are some thoughts from LLZK planning:
1538 // - Flattening could remove all structs that are not instantiated and the backend would have to
1539 // not run flattening if it wants to keep any templated structs.
1540 // - Alternatively, flattening could define different conversion targets at runtime based on
1541 // config flags to indicate if some should be flattened and others should not. This could be a
1542 // flag that indicates allow/restrict globally or based on some struct criteria, like names.
1543 //
1544 // target.addIllegalDialect<polymorphic::PolymorphicDialect>();
1545 target.addDynamicallyLegalOp<StructDefOp>(std::bind_front(isLegalStruct, false));
1546 if (failed(applyFullConversion(modOp, target, std::move(patterns)))) {
1547 return failure();
1548 }
1549
1550 // Warn about any structs that were instantiated but still have uses elsewhere.
1551 modOp->walk([&](StructDefOp op) {
1552 isLegalStruct(true, op);
1553 return WalkResult::skip(); // StructDefOp cannot be nested
1554 });
1555
1556 return success();
1557}
1558
1559} // namespace Step5_Cleanup
1560
1561class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<FlatteningPass> {
1562
1563 static constexpr unsigned LIMIT = 1000;
1564
1565 inline LogicalResult runOn(ModuleOp modOp) {
1566 {
1567 // Preliminary step: remove empty parameter lists from structs
1568 OpPassManager nestedPM(ModuleOp::getOperationName());
1569 nestedPM.addPass(createEmptyParamListRemoval());
1570 if (failed(runPipeline(nestedPM, modOp))) {
1571 return failure();
1572 }
1573 }
1574
1575 ConversionTracker tracker;
1576 unsigned loopCount = 0;
1577 do {
1578 ++loopCount;
1579 if (loopCount > LIMIT) {
1580 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << LIMIT << " iterations!\n";
1581 return failure();
1582 }
1583 tracker.resetModifiedFlag();
1584
1585 // Find calls to "compute()" that return a parameterized struct and replace it to call a
1586 // flattened version of the struct that has parameters replaced with the constant values.
1587 // Create the necessary instantiated/flattened struct in the same location as the original.
1588 if (failed(Step1_InstantiateStructs::run(modOp, tracker))) {
1589 llvm::errs() << DEBUG_TYPE << " failed while replacing concrete-parameter struct types\n";
1590 return failure();
1591 }
1592
1593 // Unroll loops with known iterations.
1594 if (failed(Step2_Unroll::run(modOp, tracker))) {
1595 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
1596 return failure();
1597 }
1598
1599 // Instantiate affine_map parameters of StructType and ArrayType.
1600 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
1601 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
1602 return failure();
1603 }
1604
1605 // Propagate updated types using the semantics of various ops.
1606 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
1607 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
1608 return failure();
1609 }
1610
1611 LLVM_DEBUG(if (tracker.isModified()) {
1612 llvm::dbgs() << "=====================================================================\n";
1613 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << " \n";
1614 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1615 llvm::dbgs() << "=====================================================================\n";
1616 });
1617 } while (tracker.isModified());
1618
1619 // Remove the parameterized StructDefOp that were instantiated.
1620 if (failed(Step5_Cleanup::run(modOp, tracker))) {
1621 llvm::errs() << DEBUG_TYPE
1622 << " failed while removing parameterized structs that were replaced with "
1623 "instantiated versions\n";
1624 return failure();
1625 }
1626
1627 return success();
1628 }
1629
1630 void runOnOperation() override {
1631 ModuleOp modOp = getOperation();
1632 if (failed(runOn(modOp))) {
1633 LLVM_DEBUG({
1634 // If the pass failed, dump the current IR.
1635 llvm::dbgs() << "=====================================================================\n";
1636 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << " \n";
1637 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1638 llvm::dbgs() << "=====================================================================\n";
1639 });
1640 signalPassFailure();
1641 }
1642 }
1643};
1644
1645} // namespace
1646
1648 return std::make_unique<FlatteningPass>();
1649};
#define DEBUG_TYPE
Common private implementation for poly dialect passes.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:649
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:412
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.cpp.inc:438
::mlir::Value getResult()
Definition Ops.cpp.inc:1467
::mlir::Value getRvalue()
Definition Ops.cpp.inc:1772
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:617
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:292
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:1112
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.cpp.inc:1143
::mlir::FailureOr< SymbolLookupResult< FieldDefOp > > getFieldDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the field referenced in this op.
Definition Ops.cpp:509
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the FieldRefOp.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:729
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
Definition Ops.cpp:142
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:919
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1971
void setConstParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.cpp.inc:1975
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:69
::mlir::ArrayAttr getParams() const
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op) const
Gets the struct op that defines this struct.
Definition Types.cpp:39
::mlir::OperandRangeRange getMapOperands()
Definition Ops.cpp.inc:228
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:694
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:700
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:688
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:526
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.cpp.inc:531
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.cpp.inc:522
::mlir::Operation::operand_range getArgOperands()
Definition Ops.cpp.inc:224
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:1095
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Returns the argument types of this function.
Definition Ops.h.inc:591
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:309
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1086
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:622
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:619
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1130
::mlir::Operation::operand_range getOperands()
Definition Ops.cpp.inc:1322
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.cpp.inc:662
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:121
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:239
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:268
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:126
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
std::unique_ptr< mlir::Pass > createFlatteningPass()
std::unique_ptr< mlir::Pass > createEmptyParamListRemoval()
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:221
bool isConcreteType(Type type, bool allowStructParams)
bool hasUsesWithin(Operation *symbol, Operation *from)
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:251
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:181
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
FailureOr< ModuleOp > getTopRootModule(Operation *from)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:255
bool isDynamic(IntegerAttr intAttr)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:32
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
int64_t fromAPInt(llvm::APInt i)