LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
TypeHelper.cpp
Go to the documentation of this file.
1//===-- TypeHelper.cpp ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
16#include "llzk/Util/Debug.h"
20
21#include <llvm/ADT/TypeSwitch.h>
22
23using namespace mlir;
24
25namespace llzk {
26
27using namespace array;
28using namespace component;
29using namespace felt;
30using namespace polymorphic;
31using namespace string;
32
35template <typename Derived, typename ResultType> struct LLZKTypeSwitch {
36 inline ResultType match(Type type) {
37 return llvm::TypeSwitch<Type, ResultType>(type)
38 .template Case<IndexType>([this](auto t) {
39 return static_cast<Derived *>(this)->caseIndex(t);
40 })
41 .template Case<FeltType>([this](auto t) {
42 return static_cast<Derived *>(this)->caseFelt(t);
43 })
44 .template Case<StringType>([this](auto t) {
45 return static_cast<Derived *>(this)->caseString(t);
46 })
47 .template Case<TypeVarType>([this](auto t) {
48 return static_cast<Derived *>(this)->caseTypeVar(t);
49 })
50 .template Case<ArrayType>([this](auto t) {
51 return static_cast<Derived *>(this)->caseArray(t);
52 })
53 .template Case<StructType>([this](auto t) {
54 return static_cast<Derived *>(this)->caseStruct(t);
55 }).Default([this](Type t) {
56 if (t.isSignlessInteger(1)) {
57 return static_cast<Derived *>(this)->caseBool(cast<IntegerType>(t));
58 } else {
59 return static_cast<Derived *>(this)->caseInvalid(t);
60 }
61 });
62 }
63};
64
65void BuildShortTypeString::appendSymName(StringRef str) {
66 if (str.empty()) {
67 ss << '?';
68 } else {
69 ss << '@' << str;
70 }
71}
72
73void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
74 appendSymName(sa.getRootReference().getValue());
75 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
76 ss << "::";
77 appendSymName(nestedRef.getValue());
78 }
79}
80
81BuildShortTypeString &BuildShortTypeString::append(Type type) {
82 size_t position = ret.size();
83
84 struct Impl : LLZKTypeSwitch<Impl, void> {
85 BuildShortTypeString &outer;
86 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
87
88 void caseInvalid(Type) { outer.ss << "!INVALID"; }
89 void caseBool(IntegerType) { outer.ss << 'b'; }
90 void caseIndex(IndexType) { outer.ss << 'i'; }
91 void caseFelt(FeltType) { outer.ss << 'f'; }
92 void caseString(StringType) { outer.ss << 's'; }
93 void caseTypeVar(TypeVarType t) {
94 outer.ss << "!t<";
95 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
96 outer.ss << '>';
97 }
98 void caseArray(ArrayType t) {
99 outer.ss << "!a<";
100 outer.append(t.getElementType());
101 outer.ss << ':';
102 outer.append(t.getDimensionSizes());
103 outer.ss << '>';
104 }
105 void caseStruct(StructType t) {
106 outer.ss << "!s<";
107 outer.appendSymRef(t.getNameRef());
108 if (ArrayAttr params = t.getParams()) {
109 outer.ss << '_';
110 outer.append(params.getValue());
111 }
112 outer.ss << '>';
113 }
114 };
115 Impl(*this).match(type);
116
117 assert(
118 ret.find(PLACEHOLDER, position) == std::string::npos &&
119 "formatting a Type should not produce the 'PLACEHOLDER' char"
120 );
121 return *this;
122}
123
124BuildShortTypeString &BuildShortTypeString::append(Attribute a) {
125 // Special case for inserting the `PLACEHOLDER`
126 if (a == nullptr) {
127 ss << PLACEHOLDER;
128 return *this;
129 }
130
131 size_t position = ret.size();
132 // Adapted from AsmPrinter::Impl::printAttributeImpl()
133 if (auto ia = llvm::dyn_cast<IntegerAttr>(a)) {
134 Type ty = ia.getType();
135 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
136 ia.getValue().print(ss, !isUnsigned);
137 } else if (auto sra = llvm::dyn_cast<SymbolRefAttr>(a)) {
138 appendSymRef(sra);
139 } else if (auto ta = llvm::dyn_cast<TypeAttr>(a)) {
140 append(ta.getValue());
141 } else if (auto ama = llvm::dyn_cast<AffineMapAttr>(a)) {
142 ss << "!m<";
143 // Filter to remove spaces from the affine_map representation
144 filtered_raw_ostream fs(ss, [](char c) { return c == ' '; });
145 ama.getValue().print(fs);
146 fs.flush();
147 ss << '>';
148 } else if (auto aa = llvm::dyn_cast<ArrayAttr>(a)) {
149 append(aa.getValue());
150 } else {
151 // All valid/legal cases must be covered above
153 }
154 assert(
155 ret.find(PLACEHOLDER, position) == std::string::npos &&
156 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
157 );
158 return *this;
159}
160
161BuildShortTypeString &BuildShortTypeString::append(ArrayRef<Attribute> attrs) {
162 llvm::interleave(attrs, ss, [this](Attribute a) { append(a); }, "_");
163 return *this;
164}
165
166std::string BuildShortTypeString::from(const std::string &base, ArrayRef<Attribute> attrs) {
167 BuildShortTypeString bldr;
168
169 bldr.ret.reserve(base.size() + attrs.size()); // reserve minimum space required
170
171 // First handle replacements of PLACEHOLDER
172 auto END = attrs.end();
173 auto IT = attrs.begin();
174 {
175 size_t start = 0;
176 for (size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
177 // Append original up to the PLACEHOLDER
178 bldr.ret.append(base, start, pos - start);
179 // Append the formatted Attribute
180 assert(IT != END && "must have an Attribute for every 'PLACEHOLDER' char");
181 bldr.append(*IT++);
182 }
183 // Append remaining suffix of the original
184 bldr.ret.append(base, start, base.size() - start);
185 }
186
187 // Append any remaining Attributes
188 if (IT != END) {
189 bldr.ss << '_';
190 bldr.append(ArrayRef(IT, END));
191 }
192
193 return bldr.ret;
194}
195
196namespace {
197
198template <typename... Types> class TypeList {
199
201 template <typename StreamType> struct Appender {
202
203 // single
204 template <typename Ty> static inline void append(StreamType &stream) {
205 stream << '\'' << Ty::name << '\'';
206 }
207
208 // multiple
209 template <typename First, typename Second, typename... Rest>
210 static void append(StreamType &stream) {
211 append<First>(stream);
212 stream << ", ";
213 append<Second, Rest...>(stream);
214 }
215
216 // full list with wrapping brackets
217 static inline void append(StreamType &stream) {
218 stream << '[';
219 append<Types...>(stream);
220 stream << ']';
221 }
222 };
223
224public:
225 // Checks if the provided value is an instance of any of `Types`
226 template <typename T> static inline bool matches(const T &value) {
227 return llvm::isa_and_present<Types...>(value);
228 }
229
230 static void reportInvalid(EmitErrorFn emitError, const Twine &foundName, const char *aspect) {
231 InFlightDiagnosticWrapper diag = emitError().append(aspect, " must be one of ");
232 Appender<InFlightDiagnosticWrapper>::append(diag);
233 diag.append(" but found '", foundName, '\'').report();
234 }
235
236 static inline void reportInvalid(EmitErrorFn emitError, Attribute found, const char *aspect) {
237 if (emitError) {
238 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() : "nullptr", aspect);
239 }
240 }
241
242 // Returns a comma-separated list formatted string of the names of `Types`
243 static inline std::string getNames() {
244 return buildStringViaCallback(Appender<llvm::raw_string_ostream>::append);
245 }
246};
247
250template <class... Ts> struct make_unique {
251 using type = TypeList<Ts...>;
252};
253
254template <class... Ts> struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
255
256template <class U, class... Us, class... Ts>
257struct make_unique<TypeList<U, Us...>, Ts...>
258 : std::conditional_t<
259 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
260 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
261
262template <class... Ts> using TypeListUnion = typename make_unique<Ts...>::type;
263
264// Dimensions in the ArrayType must be one of the following:
265// - Integer constants
266// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
267// - AffineMap (for array created within a loop where size depends on loop variable)
268using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
269
270// Parameters in the StructType must be one of the following:
271// - Integer constants
272// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
273// - Type
274// - AffineMap (for array of non-homogeneous structs)
275using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
276
277class AllowedTypes {
278 struct ColumnCheckData {
279 SymbolTableCollection *symbolTable = nullptr;
280 Operation *op = nullptr;
281 };
282
283 bool no_felt : 1 = false;
284 bool no_string : 1 = false;
285 bool no_non_signal_struct : 1 = false;
286 bool no_signal_struct : 1 = false;
287 bool no_array : 1 = false;
288 bool no_var : 1 = false;
289 bool no_int : 1 = false;
290 bool no_struct_params : 1 = false;
291 bool must_be_column : 1 = false;
292
293 ColumnCheckData columnCheck;
294
298 bool validColumns(StructType s) {
299 if (!must_be_column) {
300 return true;
301 }
302 assert(columnCheck.symbolTable);
303 assert(columnCheck.op);
304 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
305 }
306
307public:
308 constexpr AllowedTypes &noFelt() {
309 no_felt = true;
310 return *this;
311 }
312
313 constexpr AllowedTypes &noString() {
314 no_string = true;
315 return *this;
316 }
317
318 constexpr AllowedTypes &noStruct() {
319 no_non_signal_struct = true;
320 no_signal_struct = true;
321 return *this;
322 }
323
324 constexpr AllowedTypes &noStructExceptSignal() {
325 no_non_signal_struct = true;
326 no_signal_struct = false;
327 return *this;
328 }
329
330 constexpr AllowedTypes &noArray() {
331 no_array = true;
332 return *this;
333 }
334
335 constexpr AllowedTypes &noVar() {
336 no_var = true;
337 return *this;
338 }
339
340 constexpr AllowedTypes &noInt() {
341 no_int = true;
342 return *this;
343 }
344
345 constexpr AllowedTypes &noStructParams(bool noStructParams = true) {
346 no_struct_params = noStructParams;
347 return *this;
348 }
349
350 constexpr AllowedTypes &onlyInt() {
351 no_int = false;
352 return noFelt().noString().noStruct().noArray().noVar();
353 }
354
355 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
356 must_be_column = true;
357 columnCheck.symbolTable = &symbolTable;
358 columnCheck.op = op;
359 return *this;
360 }
361
362 // This is the main check for allowed types.
363 bool isValidTypeImpl(Type type);
364
365 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr) {
366 // In LLZK, the number of array dimensions must always be known, i.e., `hasRank()==true`
367 if (dimensionSizes.empty()) {
368 if (emitError) {
369 emitError().append("array must have at least one dimension").report();
370 }
371 return false;
372 }
373 // Rather than immediately returning on failure, we check all dimensions and aggregate to
374 // provide as many errors are possible in a single verifier run.
375 bool success = true;
376 for (Attribute a : dimensionSizes) {
377 if (!ArrayDimensionTypes::matches(a)) {
378 ArrayDimensionTypes::reportInvalid(emitError, a, "Array dimension");
379 success = false;
380 } else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
381 TypeList<IntegerAttr>::reportInvalid(emitError, a, "Concrete array dimension");
382 success = false;
383 } else if (failed(verifyAffineMapAttrType(emitError, a))) {
384 success = false;
385 } else if (failed(verifyIntAttrType(emitError, a))) {
386 success = false;
387 }
388 }
389 return success;
390 }
391
392 bool isValidArrayElemTypeImpl(Type type) {
393 // ArrayType element can be any valid type sans ArrayType itself.
394 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
395 }
396
397 bool isValidArrayTypeImpl(
398 Type elementType, ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr
399 ) {
400 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
401 return false;
402 }
403
404 // Ensure array element type is valid
405 if (!isValidArrayElemTypeImpl(elementType)) {
406 if (emitError) {
407 // Print proper message if `elementType` is not a valid LLZK type or
408 // if it's simply not the right kind of type for an array element.
409 if (succeeded(checkValidType(emitError, elementType))) {
410 emitError()
411 .append(
412 '\'', ArrayType::name, "' element type cannot be '",
413 elementType.getAbstractType().getName(), '\''
414 )
415 .report();
416 }
417 }
418 return false;
419 }
420 return true;
421 }
422
423 bool isValidArrayTypeImpl(Type type) {
424 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
425 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
426 }
427 return false;
428 }
429
430 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter (if any) except
431 // for the `no_struct_params` flag which requires that `params` is null or empty.
432 bool areValidStructTypeParams(ArrayAttr params, EmitErrorFn emitError = nullptr) {
433 if (isNullOrEmpty(params)) {
434 return true;
435 }
436 if (no_struct_params) {
437 return false;
438 }
439 bool success = true;
440 for (Attribute p : params) {
441 if (!StructParamTypes::matches(p)) {
442 StructParamTypes::reportInvalid(emitError, p, "Struct parameter");
443 success = false;
444 } else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
445 if (!isValidTypeImpl(tyAttr.getValue())) {
446 if (emitError) {
447 emitError().append("expected a valid LLZK type but found ", tyAttr.getValue()).report();
448 }
449 success = false;
450 }
451 } else if (no_var && !llvm::isa<IntegerAttr>(p)) {
452 TypeList<IntegerAttr>::reportInvalid(emitError, p, "Concrete struct parameter");
453 success = false;
454 } else if (failed(verifyAffineMapAttrType(emitError, p))) {
455 success = false;
456 } else if (failed(verifyIntAttrType(emitError, p))) {
457 success = false;
458 }
459 }
460
461 return success;
462 }
463};
464
465bool AllowedTypes::isValidTypeImpl(Type type) {
466 assert(
467 !(no_int && no_felt && no_string && no_var && no_non_signal_struct && no_signal_struct &&
468 no_array) &&
469 "All types have been deactivated"
470 );
471 struct Impl : LLZKTypeSwitch<Impl, bool> {
472 AllowedTypes &outer;
473 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
474
475 bool caseBool(IntegerType t) { return !outer.no_int && t.isSignlessInteger(1); }
476 bool caseIndex(IndexType) { return !outer.no_int; }
477 bool caseFelt(FeltType) { return !outer.no_felt; }
478 bool caseString(StringType) { return !outer.no_string; }
479 bool caseTypeVar(TypeVarType) { return !outer.no_var; }
480 bool caseArray(ArrayType t) {
481 return !outer.no_array &&
482 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
483 }
484 bool caseStruct(StructType t) {
485 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter.
486 if ((outer.no_signal_struct && outer.no_non_signal_struct) || !outer.validColumns(t)) {
487 return false;
488 }
489 return (!outer.no_signal_struct && isSignalType(t)) ||
490 (!outer.no_non_signal_struct && outer.areValidStructTypeParams(t.getParams()));
491 }
492 bool caseInvalid(Type _) { return false; }
493 };
494 return Impl(*this).match(type);
495}
496
497} // namespace
498
499bool isValidType(Type type) { return AllowedTypes().isValidTypeImpl(type); }
500
501bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op) {
502 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
503}
504
505bool isValidGlobalType(Type type) { return AllowedTypes().noVar().isValidTypeImpl(type); }
506
507bool isValidEmitEqType(Type type) {
508 return AllowedTypes().noString().noStructExceptSignal().isValidTypeImpl(type);
509}
510
511// Allowed types must align with StructParamTypes (defined below)
512bool isValidConstReadType(Type type) {
513 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
514}
515
516bool isValidArrayElemType(Type type) { return AllowedTypes().isValidArrayElemTypeImpl(type); }
517
518bool isValidArrayType(Type type) { return AllowedTypes().isValidArrayTypeImpl(type); }
519
520bool isConcreteType(Type type, bool allowStructParams) {
521 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
522}
523
524bool isSignalType(Type type) {
525 if (auto structParamTy = llvm::dyn_cast<StructType>(type)) {
526 return isSignalType(structParamTy);
527 }
528 return false;
529}
530
532 // Only check the leaf part of the reference (i.e., just the struct name itself) to allow cases
533 // where the `COMPONENT_NAME_SIGNAL` struct may be placed within some nesting of modules, as
534 // happens when it's imported via an IncludeOp.
535 return sType.getNameRef().getLeafReference() == COMPONENT_NAME_SIGNAL;
536}
537
538bool hasAffineMapAttr(Type type) {
539 bool encountered = false;
540 type.walk([&](AffineMapAttr a) {
541 encountered = true;
542 return WalkResult::interrupt();
543 });
544 return encountered;
545}
546
547bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); }
548
549uint64_t computeEmitEqCardinality(Type type) {
550 struct Impl : LLZKTypeSwitch<Impl, uint64_t> {
551 uint64_t caseBool(IntegerType) { return 1; }
552 uint64_t caseIndex(IndexType) { return 1; }
553 uint64_t caseFelt(FeltType) { return 1; }
554 uint64_t caseArray(ArrayType t) {
555 int64_t n = t.getNumElements();
556 assert(n >= 0);
557 return static_cast<uint64_t>(n);
558 }
559 uint64_t caseStruct(StructType t) {
560 if (isSignalType(t)) {
561 return 1;
562 }
563 llvm_unreachable("not a valid EmitEq type");
564 }
565 uint64_t caseString(StringType) { llvm_unreachable("not a valid EmitEq type"); }
566 uint64_t caseTypeVar(TypeVarType) { llvm_unreachable("tvar has unknown cardinality"); }
567 uint64_t caseInvalid(Type) { llvm_unreachable("not a valid LLZK type"); }
568 };
569 return Impl().match(type);
570}
571
572namespace {
573
582using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
583
584struct UnifierImpl {
585 ArrayRef<StringRef> rhsRevPrefix;
586 UnificationMap *unifications;
587 AffineInstantiations *affineToIntTracker;
588 // This optional function can be used to provide an exception to the standard unification
589 // rules and return a true/success result when it otherwise may not.
590 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
591
592 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
593 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
594 overrideSuccess(nullptr) {}
595
596 bool typeParamsUnify(
597 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
598 bool unifyDynamicSize = false
599 ) {
600 auto pred = [this, unifyDynamicSize](auto lhsAttr, auto rhsAttr) {
601 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
602 };
603 return (lhsParams.size() == rhsParams.size()) &&
604 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
605 }
606
607 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
608 this->affineToIntTracker = tracker;
609 return *this;
610 }
611
612 UnifierImpl &withOverrides(llvm::function_ref<bool(Type oldTy, Type newTy)> overrides) {
613 this->overrideSuccess = overrides;
614 return *this;
615 }
616
619 bool typeParamsUnify(
620 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, bool unifyDynamicSize = false
621 ) {
622 if (lhsParams && rhsParams) {
623 return typeParamsUnify(lhsParams.getValue(), rhsParams.getValue(), unifyDynamicSize);
624 }
625 // When one or the other is null, they're only equivalent if both are null
626 return !lhsParams && !rhsParams;
627 }
628
629 bool arrayTypesUnify(ArrayType lhs, ArrayType rhs) {
630 // Check if the element types of the two arrays can unify
631 if (!typesUnify(lhs.getElementType(), rhs.getElementType())) {
632 return false;
633 }
634 // Check if the dimension size attributes unify between the LHS and RHS
635 return typeParamsUnify(
636 lhs.getDimensionSizes(), rhs.getDimensionSizes(), /*unifyDynamicSize=*/true
637 );
638 }
639
640 bool structTypesUnify(StructType lhs, StructType rhs) {
641 // Check if it references the same StructDefOp, considering the additional RHS path prefix.
642 SmallVector<StringRef> rhsNames = getNames(rhs.getNameRef());
643 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
644 if (rhsNames != getNames(lhs.getNameRef())) {
645 return false;
646 }
647 // Check if the parameters unify between the LHS and RHS
648 return typeParamsUnify(lhs.getParams(), rhs.getParams(), /*unifyDynamicSize=*/false);
649 }
650
651 bool typesUnify(Type lhs, Type rhs) {
652 if (lhs == rhs) {
653 return true;
654 }
655 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
656 return true;
657 }
658 // A type variable can be any type, thus it unifies with anything.
659 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
660 track(Side::LHS, lhsTvar.getNameRef(), rhs);
661 return true;
662 }
663 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
664 track(Side::RHS, rhsTvar.getNameRef(), lhs);
665 return true;
666 }
667 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
668 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
669 }
670 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
671 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
672 }
673 return false;
674 }
675
676private:
677 template <typename Tracker, typename Key, typename Val>
678 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
679 auto key = std::make_pair(keyHead, side);
680 auto it = tracker.find(key);
681 if (it == tracker.end()) {
682 tracker.try_emplace(key, val);
683 } else if (it->getSecond() != val) {
684 it->second = nullptr;
685 }
686 }
687
688 void track(Side side, SymbolRefAttr symRef, Type ty) {
689 if (unifications) {
690 Attribute attr;
691 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
692 // If 'ty' is TypeVarType<@S>, just map to @S directly.
693 attr = tvar.getNameRef();
694 } else {
695 // Otherwise wrap as a TypeAttr.
696 attr = TypeAttr::get(ty);
697 }
698 assert(symRef);
699 assert(attr);
700 track(*unifications, side, symRef, attr);
701 }
702 }
703
704 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
705 if (unifications) {
706 // If 'attr' is TypeAttr<TypeVarType<@S>>, just map to @S directly.
707 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
708 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
709 attr = tvar.getNameRef();
710 }
711 }
712 assert(symRef);
713 assert(attr);
714 // If 'attr' is a SymbolRefAttr, map in both directions for the correctness of
715 // `isMoreConcreteUnification()` which relies on RHS check while other external
716 // checks on the UnificationMap may do LHS checks, and in the case of both being
717 // SymbolRefAttr, unification in either direction is possible.
718 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
719 track(*unifications, reverse(side), otherSymAttr, symRef);
720 }
721 track(*unifications, side, symRef, attr);
722 }
723 }
724
725 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
726 if (affineToIntTracker) {
727 assert(affineAttr);
728 assert(intAttr);
729 assert(!isDynamic(intAttr));
730 track(*affineToIntTracker, side, affineAttr, intAttr);
731 }
732 }
733
734 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr, bool unifyDynamicSize = false) {
737 // Straightforward equality check.
738 if (lhsAttr == rhsAttr) {
739 return true;
740 }
741 // AffineMapAttr can unify with IntegerAttr (other than kDynamic) because struct parameter
742 // instantiation will result in conversion of AffineMapAttr to IntegerAttr.
743 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
744 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
745 if (!isDynamic(rhsInt)) {
746 track(Side::LHS, lhsAffine, rhsInt);
747 return true;
748 }
749 }
750 }
751 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
752 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
753 if (!isDynamic(lhsInt)) {
754 track(Side::RHS, rhsAffine, lhsInt);
755 return true;
756 }
757 }
758 }
759 // If either side is a SymbolRefAttr, assume they unify because either flattening or a pass with
760 // a more involved value analysis is required to check if they are actually the same value.
761 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
762 track(Side::LHS, lhsSymRef, rhsAttr);
763 return true;
764 }
765 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
766 track(Side::RHS, rhsSymRef, lhsAttr);
767 return true;
768 }
769 // If either side is ShapedType::kDynamic then, similarly to Symbols, assume they unify.
770 // NOTE: Dynamic array dimensions (i.e. '?') are allowed in LLZK but should generally be
771 // restricted to scenarios where it can be replaced with a concrete value during the flattening
772 // pass, such as a `unifiable_cast` where the other side of the cast has concrete dimensions or
773 // extern functions with varargs.
774 if (unifyDynamicSize) {
775 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
776 if (IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
777 if (isDynamic(intAttr)) {
778 return intAttr;
779 }
780 }
781 return nullptr;
782 };
783 auto is_const_like = [](Attribute attr) {
784 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
785 };
786 if (IntegerAttr lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
787 if (is_const_like(rhsAttr)) {
788 return true;
789 }
790 }
791 if (IntegerAttr rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
792 if (is_const_like(lhsAttr)) {
793 return true;
794 }
795 }
796 }
797 // If both are type refs, check for unification of the types.
798 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
799 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
800 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
801 }
802 }
803 // Otherwise, they do not unify.
804 return false;
805 }
806};
807
808} // namespace
809
811 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
812 UnificationMap *unifications
813) {
814 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
815}
816
820 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, UnificationMap *unifications
821) {
822 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
823}
824
826 ArrayType lhs, ArrayType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
827) {
828 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
829}
830
832 StructType lhs, StructType rhs, ArrayRef<StringRef> rhsReversePrefix,
833 UnificationMap *unifications
834) {
835 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
836}
837
839 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
840) {
841 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
842}
843
845 Type oldTy, Type newTy, llvm::function_ref<bool(Type oldTy, Type newTy)> knownOldToNew
846) {
847 UnificationMap unifications;
848 AffineInstantiations affineInstantiations;
849 // Run type unification with the addition that affine map can become integer in the new type.
850 if (!UnifierImpl(&unifications)
851 .trackAffineToInt(&affineInstantiations)
852 .withOverrides(knownOldToNew)
853 .typesUnify(oldTy, newTy)) {
854 return false;
855 }
856
857 // If either map contains RHS-keyed mappings then the old type is "more concrete" than the new.
858 // In the UnificationMap, a RHS key would indicate that the new type contains a SymbolRef (i.e.
859 // the "least concrete" attribute kind) where the old type contained any other attribute. In the
860 // AffineInstantiations map, a RHS key would indicate that the new type contains an AffineMapAttr
861 // where the old type contains an IntegerAttr.
862 auto entryIsRHS = [](const auto &entry) { return entry.first.second == Side::RHS; };
863 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
864}
865
866FailureOr<IntegerAttr> forceIntType(IntegerAttr attr, EmitErrorFn emitError) {
867 if (llvm::isa<IndexType>(attr.getType())) {
868 return attr;
869 }
870 // Ensure the APInt is the right bitwidth for IndexType or else
871 // IntegerAttr::verify(..) will report an error.
872 APInt value = attr.getValue();
873 auto compare = value.getBitWidth() <=> IndexType::kInternalStorageBitWidth;
874 if (compare < 0) {
875 value = value.zext(IndexType::kInternalStorageBitWidth);
876 } else if (compare > 0) {
877 return emitError().append("value is too large for `index` type: ", debug::toStringOne(value));
878 }
879 return IntegerAttr::get(IndexType::get(attr.getContext()), value);
880}
881
882FailureOr<Attribute> forceIntAttrType(Attribute attr, EmitErrorFn emitError) {
883 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
884 return forceIntType(intAttr, emitError);
885 }
886 return attr;
887}
888
889FailureOr<SmallVector<Attribute>>
890forceIntAttrTypes(ArrayRef<Attribute> attrList, EmitErrorFn emitError) {
891 SmallVector<Attribute> result;
892 for (Attribute attr : attrList) {
893 FailureOr<Attribute> forced = forceIntAttrType(attr, emitError);
894 if (failed(forced)) {
895 return failure();
896 }
897 result.push_back(*forced);
898 }
899 return result;
900}
901
902LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in) {
903 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
904 Type attrTy = intAttr.getType();
905 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
906 if (emitError) {
907 emitError()
908 .append("IntegerAttr must have type 'index' or 'i1' but found '", attrTy, '\'')
909 .report();
910 }
911 return failure();
912 }
913 }
914 return success();
915}
916
917LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in) {
918 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
919 AffineMap map = affineAttr.getValue();
920 if (map.getNumResults() != 1) {
921 if (emitError) {
922 emitError()
923 .append(
924 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
925 " results"
926 )
927 .report();
928 }
929 return failure();
930 }
931 }
932 return success();
933}
934
935LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params) {
936 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
937}
938
939LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes) {
940 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
941}
942
943LogicalResult
944verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes) {
945 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
946}
947
948void assertValidAttrForParamOfType(Attribute attr) {
949 // Must be the union of valid attribute types within ArrayType, StructType, and TypeVarType.
950 using TypeVarAttrs = TypeList<SymbolRefAttr>; // per ODS spec of TypeVarType
951 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
952 llvm::report_fatal_error(
953 "Legal type parameters are inconsistent. Encountered " +
954 attr.getAbstractAttribute().getName()
955 );
956 }
957}
958
959LogicalResult
960verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType) {
961 ArrayRef<Attribute> dimsFromArr = arrayType.getDimensionSizes();
962 size_t numArrDims = dimsFromArr.size();
963 ArrayRef<Attribute> dimsFromSubArr = subArrayType.getDimensionSizes();
964 size_t numSubArrDims = dimsFromSubArr.size();
965
966 if (numArrDims < numSubArrDims) {
967 return emitError().append(
968 "subarray type ", subArrayType, " has more dimensions than array type ", arrayType
969 );
970 }
971
972 size_t toDrop = numArrDims - numSubArrDims;
973 ArrayRef<Attribute> dimsFromArrReduced = dimsFromArr.drop_front(toDrop);
974
975 // Ensure dimension sizes are compatible (ignoring the indexed dimensions)
976 if (!typeParamsUnify(dimsFromArrReduced, dimsFromSubArr)) {
977 std::string message;
978 llvm::raw_string_ostream ss(message);
979 auto appendOne = [&ss](Attribute a) { appendWithoutType(ss, a); };
980 ss << "cannot unify array dimensions [";
981 llvm::interleaveComma(dimsFromArrReduced, ss, appendOne);
982 ss << "] with [";
983 llvm::interleaveComma(dimsFromSubArr, ss, appendOne);
984 ss << "]";
985 return emitError().append(message);
986 }
987
988 // Ensure element types of the arrays are compatible
989 if (!typesUnify(arrayType.getElementType(), subArrayType.getElementType())) {
990 return emitError().append(
991 "incorrect array element type; expected: ", arrayType.getElementType(),
992 ", found: ", subArrayType.getElementType()
993 );
994 }
995
996 return success();
997}
998
999LogicalResult
1000verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType) {
1001 if (auto subArrayType = llvm::dyn_cast<ArrayType>(subArrayOrElemType)) {
1002 return verifySubArrayType(emitError, arrayType, subArrayType);
1003 }
1004 if (!typesUnify(arrayType.getElementType(), subArrayOrElemType)) {
1005 return emitError().append(
1006 "incorrect array element type; expected: ", arrayType.getElementType(),
1007 ", found: ", subArrayOrElemType
1008 );
1009 }
1010
1011 return success();
1012}
1013
1014} // namespace llzk
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
Definition TypeHelper.h:36
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::mlir::SymbolRefAttr getNameRef() const
std::string toStringOne(const T &value)
Definition Debug.h:175
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:185
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:143
bool isSignalType(Type type)
int64_t fromAPInt(const llvm::APInt &i)
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
Definition Constants.h:16
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
void appendWithoutType(mlir::raw_ostream &os, mlir::Attribute a)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)