LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
TypeHelper.cpp
Go to the documentation of this file.
1//===-- TypeHelper.cpp ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
19
20#include <llvm/ADT/TypeSwitch.h>
21
22using namespace mlir;
23
24namespace llzk {
25
26using namespace array;
27using namespace component;
28using namespace felt;
29using namespace polymorphic;
30using namespace string;
31
34template <typename Derived, typename ResultType> struct LLZKTypeSwitch {
35 inline ResultType match(Type type) {
36 return llvm::TypeSwitch<Type, ResultType>(type)
37 .template Case<IndexType>([this](auto t) {
38 return static_cast<Derived *>(this)->caseIndex(t);
39 })
40 .template Case<FeltType>([this](auto t) {
41 return static_cast<Derived *>(this)->caseFelt(t);
42 })
43 .template Case<StringType>([this](auto t) {
44 return static_cast<Derived *>(this)->caseString(t);
45 })
46 .template Case<TypeVarType>([this](auto t) {
47 return static_cast<Derived *>(this)->caseTypeVar(t);
48 })
49 .template Case<ArrayType>([this](auto t) {
50 return static_cast<Derived *>(this)->caseArray(t);
51 })
52 .template Case<StructType>([this](auto t) {
53 return static_cast<Derived *>(this)->caseStruct(t);
54 }).template Default([this](Type t) {
55 if (t.isSignlessInteger(1)) {
56 return static_cast<Derived *>(this)->caseBool(cast<IntegerType>(t));
57 } else {
58 return static_cast<Derived *>(this)->caseInvalid(t);
59 }
60 });
61 }
62};
63
64void BuildShortTypeString::appendSymName(StringRef str) {
65 if (str.empty()) {
66 ss << '?';
67 } else {
68 ss << '@' << str;
69 }
70}
71
72void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
73 appendSymName(sa.getRootReference().getValue());
74 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
75 ss << "::";
76 appendSymName(nestedRef.getValue());
77 }
78}
79
80BuildShortTypeString &BuildShortTypeString::append(Type type) {
81 size_t position = ret.size();
82
83 struct Impl : LLZKTypeSwitch<Impl, void> {
84 BuildShortTypeString &outer;
85 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
86
87 void caseInvalid(Type _) { outer.ss << "!INVALID"; }
88 void caseBool(IntegerType _) { outer.ss << 'b'; }
89 void caseIndex(IndexType _) { outer.ss << 'i'; }
90 void caseFelt(FeltType _) { outer.ss << 'f'; }
91 void caseString(StringType _) { outer.ss << 's'; }
92 void caseTypeVar(TypeVarType t) {
93 outer.ss << "!t<";
94 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
95 outer.ss << '>';
96 }
97 void caseArray(ArrayType t) {
98 outer.ss << "!a<";
99 outer.append(t.getElementType());
100 outer.ss << ':';
101 outer.append(t.getDimensionSizes());
102 outer.ss << '>';
103 }
104 void caseStruct(StructType t) {
105 outer.ss << "!s<";
106 outer.appendSymRef(t.getNameRef());
107 if (ArrayAttr params = t.getParams()) {
108 outer.ss << '_';
109 outer.append(params.getValue());
110 }
111 outer.ss << '>';
112 }
113 };
114 Impl(*this).match(type);
115
116 assert(
117 ret.find(PLACEHOLDER, position) == std::string::npos &&
118 "formatting a Type should not produce the 'PLACEHOLDER' char"
119 );
120 return *this;
121}
122
/// Append the short-form text of attribute `a` to the string being built.
/// - null: emits the `PLACEHOLDER` char (the only case allowed to produce it)
/// - IntegerAttr: its value, printed unsigned for unsigned integer types and `i1`, else signed
/// - SymbolRefAttr: the symbol path via `appendSymRef`
/// - TypeAttr: the wrapped type via `append(Type)`
/// - AffineMapAttr: `!m<...>` with all spaces removed from the map's printed form
/// - ArrayAttr: its elements via `append(ArrayRef<Attribute>)`
123BuildShortTypeString &BuildShortTypeString::append(Attribute a) {
124 // Special case for inserting the `PLACEHOLDER`
125 if (a == nullptr) {
126 ss << PLACEHOLDER;
127 return *this;
128 }
129
130 size_t position = ret.size();
131 // Adapted from AsmPrinter::Impl::printAttributeImpl()
132 if (llvm::isa<IntegerAttr>(a)) {
133 IntegerAttr ia = llvm::cast<IntegerAttr>(a);
134 Type ty = ia.getType();
// `i1` is treated as unsigned so that `true` prints as 1 rather than -1.
135 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
136 ia.getValue().print(ss, !isUnsigned);
137 } else if (llvm::isa<SymbolRefAttr>(a)) {
138 appendSymRef(llvm::cast<SymbolRefAttr>(a));
139 } else if (llvm::isa<TypeAttr>(a)) {
140 append(llvm::cast<TypeAttr>(a).getValue());
141 } else if (llvm::isa<AffineMapAttr>(a)) {
142 ss << "!m<";
143 // Filter to remove spaces from the affine_map representation
144 filtered_raw_ostream fs(ss, [](char c) { return c == ' '; });
145 llvm::cast<AffineMapAttr>(a).getValue().print(fs);
146 fs.flush();
147 ss << '>';
148 } else if (llvm::isa<ArrayAttr>(a)) {
149 append(llvm::cast<ArrayAttr>(a).getValue());
150 } else {
151 // All valid/legal cases must be covered above
153 }
// Only the text appended by this call is checked (search starts at `position`).
154 assert(
155 ret.find(PLACEHOLDER, position) == std::string::npos &&
156 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
157 );
158 return *this;
159}
160
161BuildShortTypeString &BuildShortTypeString::append(ArrayRef<Attribute> attrs) {
162 llvm::interleave(attrs, ss, [this](Attribute a) { append(a); }, "_");
163 return *this;
164}
165
166std::string BuildShortTypeString::from(const std::string &base, ArrayRef<Attribute> attrs) {
167 BuildShortTypeString bldr;
168
169 bldr.ret.reserve(base.size() + attrs.size()); // reserve minimum space required
170
171 // First handle replacements of PLACEHOLDER
172 auto END = attrs.end();
173 auto IT = attrs.begin();
174 {
175 size_t start = 0;
176 for (size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
177 // Append original up to the PLACEHOLDER
178 bldr.ret.append(base, start, pos - start);
179 // Append the formatted Attribute
180 assert(IT != END && "must have an Attribute for every 'PLACEHOLDER' char");
181 bldr.append(*IT++);
182 }
183 // Append remaining suffix of the original
184 bldr.ret.append(base, start, base.size() - start);
185 }
186
187 // Append any remaining Attributes
188 if (IT != END) {
189 bldr.ss << '_';
190 bldr.append(ArrayRef(IT, END));
191 }
192
193 return bldr.ret;
194}
195
196namespace {
197
198template <typename... Types> class TypeList {
199
201 template <typename StreamType> struct Appender {
202
203 // single
204 template <typename Ty> static inline void append(StreamType &stream) {
205 stream << '\'' << Ty::name << '\'';
206 }
207
208 // multiple
209 template <typename First, typename Second, typename... Rest>
210 static void append(StreamType &stream) {
211 append<First>(stream);
212 stream << ", ";
213 append<Second, Rest...>(stream);
214 }
215
216 // full list with wrapping brackets
217 static inline void append(StreamType &stream) {
218 stream << '[';
219 append<Types...>(stream);
220 stream << ']';
221 }
222 };
223
224public:
225 // Checks if the provided value is an instance of any of `Types`
226 template <typename T> static inline bool matches(const T &value) {
227 return llvm::isa_and_present<Types...>(value);
228 }
229
230 static void reportInvalid(EmitErrorFn emitError, const Twine &foundName, const char *aspect) {
231 InFlightDiagnostic diag = emitError().append(aspect, " must be one of ");
232 Appender<InFlightDiagnostic>::append(diag);
233 diag.append(" but found '", foundName, '\'').report();
234 }
235
236 static inline void reportInvalid(EmitErrorFn emitError, Attribute found, const char *aspect) {
237 if (emitError) {
238 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() : "nullptr", aspect);
239 }
240 }
241
242 // Returns a comma-separated list formatted string of the names of `Types`
243 static inline std::string getNames() {
244 return buildStringViaCallback(Appender<llvm::raw_string_ostream>::append);
245 }
246};
247
250template <class... Ts> struct make_unique {
251 using type = TypeList<Ts...>;
252};
253
254template <class... Ts> struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
255
256template <class U, class... Us, class... Ts>
257struct make_unique<TypeList<U, Us...>, Ts...>
258 : std::conditional_t<
259 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
260 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
261
262template <class... Ts> using TypeListUnion = typename make_unique<Ts...>::type;
263
264// Dimensions in the ArrayType must be one of the following:
265// - Integer constants
266// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
267// - AffineMap (for array created within a loop where size depends on loop variable)
268using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
269
270// Parameters in the StructType must be one of the following:
271// - Integer constants
272// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
273// - Type
274// - AffineMap (for array of non-homogeneous structs)
275using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
276
277class AllowedTypes {
278 struct ColumnCheckData {
279 SymbolTableCollection *symbolTable = nullptr;
280 Operation *op = nullptr;
281 };
282
283 bool no_felt : 1 = false;
284 bool no_string : 1 = false;
285 bool no_non_signal_struct : 1 = false;
286 bool no_signal_struct : 1 = false;
287 bool no_array : 1 = false;
288 bool no_var : 1 = false;
289 bool no_int : 1 = false;
290 bool no_struct_params : 1 = false;
291 bool must_be_column : 1 = false;
292
293 ColumnCheckData columnCheck;
294
298 bool validColumns(StructType s) {
299 if (!must_be_column) {
300 return true;
301 }
302 assert(columnCheck.symbolTable);
303 assert(columnCheck.op);
304 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
305 }
306
307public:
308 constexpr AllowedTypes &noFelt() {
309 no_felt = true;
310 return *this;
311 }
312
313 constexpr AllowedTypes &noString() {
314 no_string = true;
315 return *this;
316 }
317
318 constexpr AllowedTypes &noStruct() {
319 no_non_signal_struct = true;
320 no_signal_struct = true;
321 return *this;
322 }
323
324 constexpr AllowedTypes &noStructExceptSignal() {
325 no_non_signal_struct = true;
326 no_signal_struct = false;
327 return *this;
328 }
329
330 constexpr AllowedTypes &noArray() {
331 no_array = true;
332 return *this;
333 }
334
335 constexpr AllowedTypes &noVar() {
336 no_var = true;
337 return *this;
338 }
339
340 constexpr AllowedTypes &noInt() {
341 no_int = true;
342 return *this;
343 }
344
345 constexpr AllowedTypes &noStructParams(bool noStructParams = true) {
346 no_struct_params = noStructParams;
347 return *this;
348 }
349
350 constexpr AllowedTypes &onlyInt() {
351 no_int = false;
352 return noFelt().noString().noStruct().noArray().noVar();
353 }
354
355 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
356 must_be_column = true;
357 columnCheck.symbolTable = &symbolTable;
358 columnCheck.op = op;
359 return *this;
360 }
361
362 // This is the main check for allowed types.
363 bool isValidTypeImpl(Type type);
364
365 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr) {
366 // In LLZK, the number of array dimensions must always be known, i.e., `hasRank()==true`
367 if (dimensionSizes.empty()) {
368 if (emitError) {
369 emitError().append("array must have at least one dimension").report();
370 }
371 return false;
372 }
373 // Rather than immediately returning on failure, we check all dimensions and aggregate to
374 // provide as many errors are possible in a single verifier run.
375 bool success = true;
376 for (Attribute a : dimensionSizes) {
377 if (!ArrayDimensionTypes::matches(a)) {
378 ArrayDimensionTypes::reportInvalid(emitError, a, "Array dimension");
379 success = false;
380 } else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
381 TypeList<IntegerAttr>::reportInvalid(emitError, a, "Concrete array dimension");
382 success = false;
383 } else if (failed(verifyAffineMapAttrType(emitError, a))) {
384 success = false;
385 } else if (failed(verifyIntAttrType(emitError, a))) {
386 success = false;
387 }
388 }
389 return success;
390 }
391
392 bool isValidArrayElemTypeImpl(Type type) {
393 // ArrayType element can be any valid type sans ArrayType itself.
394 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
395 }
396
397 bool isValidArrayTypeImpl(
398 Type elementType, ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr
399 ) {
400 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
401 return false;
402 }
403
404 // Ensure array element type is valid
405 if (!isValidArrayElemTypeImpl(elementType)) {
406 if (emitError) {
407 // Print proper message if `elementType` is not a valid LLZK type or
408 // if it's simply not the right kind of type for an array element.
409 if (succeeded(checkValidType(emitError, elementType))) {
410 emitError()
411 .append(
412 '\'', ArrayType::name, "' element type cannot be '",
413 elementType.getAbstractType().getName(), '\''
414 )
415 .report();
416 }
417 }
418 return false;
419 }
420 return true;
421 }
422
423 bool isValidArrayTypeImpl(Type type) {
424 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
425 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
426 }
427 return false;
428 }
429
430 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter (if any) except
431 // for the `no_struct_params` flag which requires that `params` is null or empty.
432 bool areValidStructTypeParams(ArrayAttr params, EmitErrorFn emitError = nullptr) {
433 if (isNullOrEmpty(params)) {
434 return true;
435 }
436 if (no_struct_params) {
437 return false;
438 }
439 bool success = true;
440 for (Attribute p : params) {
441 if (!StructParamTypes::matches(p)) {
442 StructParamTypes::reportInvalid(emitError, p, "Struct parameter");
443 success = false;
444 } else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
445 if (!isValidTypeImpl(tyAttr.getValue())) {
446 if (emitError) {
447 emitError().append("expected a valid LLZK type but found ", tyAttr.getValue()).report();
448 }
449 success = false;
450 }
451 } else if (no_var && !llvm::isa<IntegerAttr>(p)) {
452 TypeList<IntegerAttr>::reportInvalid(emitError, p, "Concrete struct parameter");
453 success = false;
454 } else if (failed(verifyAffineMapAttrType(emitError, p))) {
455 success = false;
456 } else if (failed(verifyIntAttrType(emitError, p))) {
457 success = false;
458 }
459 }
460
461 return success;
462 }
463};
464
465bool AllowedTypes::isValidTypeImpl(Type type) {
466 assert(
467 !(no_int && no_felt && no_string && no_var && no_non_signal_struct && no_signal_struct &&
468 no_array) &&
469 "All types have been deactivated"
470 );
471 struct Impl : LLZKTypeSwitch<Impl, bool> {
472 AllowedTypes &outer;
473 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
474
475 bool caseBool(IntegerType t) { return !outer.no_int && t.isSignlessInteger(1); }
476 bool caseIndex(IndexType _) { return !outer.no_int; }
477 bool caseFelt(FeltType _) { return !outer.no_felt; }
478 bool caseString(StringType _) { return !outer.no_string; }
479 bool caseTypeVar(TypeVarType _) { return !outer.no_var; }
480 bool caseArray(ArrayType t) {
481 return !outer.no_array &&
482 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
483 }
484 bool caseStruct(StructType t) {
485 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter.
486 if ((outer.no_signal_struct && outer.no_non_signal_struct) || !outer.validColumns(t)) {
487 return false;
488 }
489 return (!outer.no_signal_struct && isSignalType(t)) ||
490 (!outer.no_non_signal_struct && outer.areValidStructTypeParams(t.getParams()));
491 }
492 bool caseInvalid(Type _) { return false; }
493 };
494 return Impl(*this).match(type);
495}
496
497} // namespace
498
499bool isValidType(Type type) { return AllowedTypes().isValidTypeImpl(type); }
500
501bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op) {
502 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
503}
504
505bool isValidGlobalType(Type type) { return AllowedTypes().noVar().isValidTypeImpl(type); }
506
507bool isValidEmitEqType(Type type) {
508 return AllowedTypes().noString().noStructExceptSignal().isValidTypeImpl(type);
509}
510
511// Allowed types must align with StructParamTypes (defined below)
512bool isValidConstReadType(Type type) {
513 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
514}
515
516bool isValidArrayElemType(Type type) { return AllowedTypes().isValidArrayElemTypeImpl(type); }
517
518bool isValidArrayType(Type type) { return AllowedTypes().isValidArrayTypeImpl(type); }
519
520bool isConcreteType(Type type, bool allowStructParams) {
521 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
522}
523
524bool isSignalType(Type type) {
525 if (auto structParamTy = llvm::dyn_cast<StructType>(type)) {
526 return isSignalType(structParamTy);
527 }
528 return false;
529}
530
532 // Only check the leaf part of the reference (i.e., just the struct name itself) to allow cases
533 // where the `COMPONENT_NAME_SIGNAL` struct may be placed within some nesting of modules, as
534 // happens when it's imported via an IncludeOp.
535 return sType.getNameRef().getLeafReference() == COMPONENT_NAME_SIGNAL;
536}
537
538bool hasAffineMapAttr(Type type) {
539 bool encountered = false;
540 type.walk([&](AffineMapAttr a) {
541 encountered = true;
542 return WalkResult::interrupt();
543 });
544 return encountered;
545}
546
547bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); }
548
549uint64_t computeEmitEqCardinality(Type type) {
550 struct Impl : LLZKTypeSwitch<Impl, uint64_t> {
551 uint64_t caseBool(IntegerType _) { return 1; }
552 uint64_t caseIndex(IndexType _) { return 1; }
553 uint64_t caseFelt(FeltType _) { return 1; }
554 uint64_t caseArray(ArrayType t) {
555 int64_t n = t.getNumElements();
556 assert(n >= 0);
557 return static_cast<uint64_t>(n);
558 }
559 uint64_t caseStruct(StructType t) {
560 if (isSignalType(t)) {
561 return 1;
562 }
563 llvm_unreachable("not a valid EmitEq type");
564 }
565 uint64_t caseString(StringType _) { llvm_unreachable("not a valid EmitEq type"); }
566 uint64_t caseTypeVar(TypeVarType _) { llvm_unreachable("tvar has unknown cardinality"); }
567 uint64_t caseInvalid(Type _) { llvm_unreachable("not a valid LLZK type"); }
568 };
569 return Impl().match(type);
570}
571
572namespace {
573
// Records, per (AffineMapAttr, Side) key, the IntegerAttr it unified with. A second, different
// value for the same key overwrites the entry with nullptr (see the generic `track` helper).
582using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
583
// Structural unification of LLZK types and type parameters. Symbol / type-variable bindings are
// recorded in `unifications` and AffineMapAttr-to-IntegerAttr instantiations in
// `affineToIntTracker`, each only when the corresponding pointer is non-null.
584struct UnifierImpl {
// Path prefix stored in reverse order; structTypesUnify() reverses it back and prepends it to
// the RHS struct's name components before comparing with the LHS.
585 ArrayRef<StringRef> rhsRevPrefix;
586 UnificationMap *unifications;
587 AffineInstantiations *affineToIntTracker;
588 // This optional function can be used to provide an exception to the standard unification
589 // rules and return a true/success result when it otherwise may not.
590 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
591
592 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
593 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
594 overrideSuccess(nullptr) {}
595
// True iff both lists have the same length and each pair unifies via paramAttrUnify().
596 bool typeParamsUnify(
597 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
598 bool unifyDynamicSize = false
599 ) {
600 auto pred = [this, unifyDynamicSize](auto lhsAttr, auto rhsAttr) {
601 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
602 };
603 return (lhsParams.size() == rhsParams.size()) &&
604 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
605 }
606
// Builder-style setter: enable recording of AffineMapAttr -> IntegerAttr instantiations.
607 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
608 this->affineToIntTracker = tracker;
609 return *this;
610 }
611
// Builder-style setter: install the `overrideSuccess` callback consulted by typesUnify().
612 UnifierImpl &withOverrides(llvm::function_ref<bool(Type oldTy, Type newTy)> overrides) {
613 this->overrideSuccess = overrides;
614 return *this;
615 }
616
619 bool typeParamsUnify(
620 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, bool unifyDynamicSize = false
621 ) {
622 if (lhsParams && rhsParams) {
623 return typeParamsUnify(lhsParams.getValue(), rhsParams.getValue(), unifyDynamicSize);
624 }
625 // When one or the other is null, they're only equivalent if both are null
626 return !lhsParams && !rhsParams;
627 }
628
629 bool arrayTypesUnify(ArrayType lhs, ArrayType rhs) {
630 // Check if the element types of the two arrays can unify
631 if (!typesUnify(lhs.getElementType(), rhs.getElementType())) {
632 return false;
633 }
634 // Check if the dimension size attributes unify between the LHS and RHS
635 return typeParamsUnify(
636 lhs.getDimensionSizes(), rhs.getDimensionSizes(), /*unifyDynamicSize=*/true
637 );
638 }
639
640 bool structTypesUnify(StructType lhs, StructType rhs) {
641 // Check if it references the same StructDefOp, considering the additional RHS path prefix.
642 SmallVector<StringRef> rhsNames = getNames(rhs.getNameRef());
643 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
644 if (rhsNames != getNames(lhs.getNameRef())) {
645 return false;
646 }
647 // Check if the parameters unify between the LHS and RHS
648 return typeParamsUnify(lhs.getParams(), rhs.getParams());
649 }
650
// Order of checks: identical types, the override callback, a type variable on either side
// (recorded as a binding), then structural comparison of struct/array pairs.
651 bool typesUnify(Type lhs, Type rhs) {
652 if (lhs == rhs) {
653 return true;
654 }
655 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
656 return true;
657 }
658 // A type variable can be any type, thus it unifies with anything.
659 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
660 track(Side::LHS, lhsTvar.getNameRef(), rhs);
661 return true;
662 }
663 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
664 track(Side::RHS, rhsTvar.getNameRef(), lhs);
665 return true;
666 }
667 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
668 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
669 }
670 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
671 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
672 }
673 return false;
674 }
675
676private:
// Records (keyHead, side) -> val. If a different value is already recorded for that key, the
// entry is overwritten with nullptr.
677 template <typename Tracker, typename Key, typename Val>
678 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
679 auto key = std::make_pair(keyHead, side);
680 auto it = tracker.find(key);
681 if (it == tracker.end()) {
682 tracker.try_emplace(key, val);
683 } else if (it->getSecond() != val) {
684 it->second = nullptr;
685 }
686 }
687
688 void track(Side side, SymbolRefAttr symRef, Type ty) {
689 if (unifications) {
690 Attribute attr;
691 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
692 // If 'ty' is TypeVarType<@S>, just map to @S directly.
693 attr = tvar.getNameRef();
694 } else {
695 // Otherwise wrap as a TypeAttr.
696 attr = TypeAttr::get(ty);
697 }
698 assert(symRef);
699 assert(attr);
700 track(*unifications, side, symRef, attr);
701 }
702 }
703
704 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
705 if (unifications) {
706 // If 'attr' is TypeAttr<TypeVarType<@S>>, just map to @S directly.
707 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
708 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
709 attr = tvar.getNameRef();
710 }
711 }
712 assert(symRef);
713 assert(attr);
714 // If 'attr' is a SymbolRefAttr, map in both directions for the correctness of
715 // `isMoreConcreteUnification()` which relies on RHS check while other external
716 // checks on the UnificationMap may do LHS checks, and in the case of both being
717 // SymbolRefAttr, unification in either direction is possible.
718 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
719 track(*unifications, reverse(side), otherSymAttr, symRef);
720 }
721 track(*unifications, side, symRef, attr);
722 }
723 }
724
725 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
726 if (affineToIntTracker) {
727 assert(affineAttr);
728 assert(intAttr);
729 assert(!isDynamic(intAttr));
730 track(*affineToIntTracker, side, affineAttr, intAttr);
731 }
732 }
733
// NOTE(review): `unifyDynamicSize` is not referenced in the body lines visible in this listing.
// The kDynamic rule below applies unconditionally. Confirm whether it is meant to be gated on
// this flag (arrayTypesUnify passes true; struct params use the default false).
734 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr, bool unifyDynamicSize = false) {
737 // Straightforward equality check.
738 if (lhsAttr == rhsAttr) {
739 return true;
740 }
741 // AffineMapAttr can unify with IntegerAttr (other than kDynamic) because struct parameter
742 // instantiation will result in conversion of AffineMapAttr to IntegerAttr.
743 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
744 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
745 if (!isDynamic(rhsInt)) {
746 track(Side::LHS, lhsAffine, rhsInt);
747 return true;
748 }
749 }
750 }
751 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
752 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
753 if (!isDynamic(lhsInt)) {
754 track(Side::RHS, rhsAffine, lhsInt);
755 return true;
756 }
757 }
758 }
759 // If either side is a SymbolRefAttr, assume they unify because either flattening or a pass with
760 // a more involved value analysis is required to check if they are actually the same value.
761 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
762 track(Side::LHS, lhsSymRef, rhsAttr);
763 return true;
764 }
765 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
766 track(Side::RHS, rhsSymRef, lhsAttr);
767 return true;
768 }
769 // If either side is ShapedType::kDynamic then, similarly to Symbols, assume they unify.
770 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
771 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
772 if (isDynamic(intAttr)) {
773 return intAttr;
774 }
775 }
776 return nullptr;
777 };
778 auto isa_const = [](Attribute attr) {
779 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
780 };
781 if (auto lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
782 if (isa_const(rhsAttr)) {
783 return true;
784 }
785 }
786 if (auto rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
787 if (isa_const(lhsAttr)) {
788 return true;
789 }
790 }
791 // If both are type refs, check for unification of the types.
792 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
793 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
794 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
795 }
796 }
797 // Otherwise, they do not unify.
798 return false;
799 }
800};
801
802} // namespace
803
805 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
806 UnificationMap *unifications
807) {
// Delegates to UnifierImpl with an empty RHS path prefix and no affine-map tracking.
808 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
809}

814 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, UnificationMap *unifications
815) {
// Delegates to UnifierImpl; two null ArrayAttrs unify, while exactly one null does not.
816 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
817}

820 ArrayType lhs, ArrayType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
821) {
822 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
823}

826 StructType lhs, StructType rhs, ArrayRef<StringRef> rhsReversePrefix,
827 UnificationMap *unifications
828) {
829 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
830}

833 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
834) {
835 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
836}

839 Type oldTy, Type newTy, llvm::function_ref<bool(Type oldTy, Type newTy)> knownOldToNew
840) {
// True iff `oldTy` unifies with `newTy` (with `knownOldToNew` as an override) and no binding
// recorded during that unification is keyed on the RHS (new-type) side.
841 UnificationMap unifications;
842 AffineInstantiations affineInstantiations;
843 // Run type unification with the addition that affine map can become integer in the new type.
844 if (!UnifierImpl(&unifications)
845 .trackAffineToInt(&affineInstantiations)
846 .withOverrides(knownOldToNew)
847 .typesUnify(oldTy, newTy)) {
848 return false;
849 }
850
851 // If either map contains RHS-keyed mappings then the old type is "more concrete" than the new.
852 // In the UnificationMap, a RHS key would indicate that the new type contains a SymbolRef (i.e.
853 // the "least concrete" attribute kind) where the old type contained any other attribute. In the
854 // AffineInstantiations map, a RHS key would indicate that the new type contains an AffineMapAttr
855 // where the old type contains an IntegerAttr.
856 auto entryIsRHS = [](const auto &entry) { return entry.first.second == Side::RHS; };
857 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
858}
859
860IntegerAttr forceIntType(IntegerAttr attr) {
861 if (AllowedTypes().onlyInt().isValidTypeImpl(attr.getType())) {
862 return attr;
863 }
864 return IntegerAttr::get(IndexType::get(attr.getContext()), attr.getValue());
865}
866
867Attribute forceIntAttrType(Attribute attr) {
868 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
869 return forceIntType(intAttr);
870 }
871 return attr;
872}
873
874SmallVector<Attribute> forceIntAttrTypes(ArrayRef<Attribute> attrList) {
875 return llvm::map_to_vector(attrList, forceIntAttrType);
876}
877
878LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in) {
879 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
880 Type attrTy = intAttr.getType();
881 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
882 if (emitError) {
883 emitError()
884 .append("IntegerAttr must have type 'index' or 'i1' but found '", attrTy, '\'')
885 .report();
886 }
887 return failure();
888 }
889 }
890 return success();
891}
892
893LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in) {
894 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
895 AffineMap map = affineAttr.getValue();
896 if (map.getNumResults() != 1) {
897 if (emitError) {
898 emitError()
899 .append(
900 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
901 " results"
902 )
903 .report();
904 }
905 return failure();
906 }
907 }
908 return success();
909}
910
911LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params) {
912 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
913}
914
915LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes) {
916 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
917}
918
919LogicalResult
920verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes) {
921 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
922}
923
924void assertValidAttrForParamOfType(Attribute attr) {
925 // Must be the union of valid attribute types within ArrayType, StructType, and TypeVarType.
926 using TypeVarAttrs = TypeList<SymbolRefAttr>; // per ODS spec of TypeVarType
927 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
928 llvm::report_fatal_error(
929 "Legal type parameters are inconsistent. Encountered " +
930 attr.getAbstractAttribute().getName()
931 );
932 }
933}
934
935} // namespace llzk
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
Definition TypeHelper.h:36
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::mlir::SymbolRefAttr getNameRef() const
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
llvm::function_ref< mlir::InFlightDiagnostic()> EmitErrorFn
Definition ErrorHelper.h:21
bool isValidGlobalType(Type type)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
SmallVector< Attribute > forceIntAttrTypes(ArrayRef< Attribute > attrList)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:185
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:143
bool isSignalType(Type type)
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
Definition Constants.h:16
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
Attribute forceIntAttrType(Attribute attr)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
IntegerAttr forceIntType(IntegerAttr attr)
bool hasAffineMapAttr(Type type)
int64_t fromAPInt(llvm::APInt i)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)