LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
TypeHelper.cpp
Go to the documentation of this file.
1//===-- TypeHelper.cpp ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
16#include "llzk/Util/Debug.h"
20
21#include <llvm/ADT/TypeSwitch.h>
22
23using namespace mlir;
24
25namespace llzk {
26
27using namespace array;
28using namespace component;
29using namespace felt;
30using namespace polymorphic;
31using namespace string;
32
35template <typename Derived, typename ResultType> struct LLZKTypeSwitch {
36 inline ResultType match(Type type) {
37 return llvm::TypeSwitch<Type, ResultType>(type)
38 .template Case<IndexType>([this](auto t) {
39 return static_cast<Derived *>(this)->caseIndex(t);
40 })
41 .template Case<FeltType>([this](auto t) {
42 return static_cast<Derived *>(this)->caseFelt(t);
43 })
44 .template Case<StringType>([this](auto t) {
45 return static_cast<Derived *>(this)->caseString(t);
46 })
47 .template Case<TypeVarType>([this](auto t) {
48 return static_cast<Derived *>(this)->caseTypeVar(t);
49 })
50 .template Case<ArrayType>([this](auto t) {
51 return static_cast<Derived *>(this)->caseArray(t);
52 })
53 .template Case<StructType>([this](auto t) {
54 return static_cast<Derived *>(this)->caseStruct(t);
55 }).Default([this](Type t) {
56 if (t.isSignlessInteger(1)) {
57 return static_cast<Derived *>(this)->caseBool(cast<IntegerType>(t));
58 } else {
59 return static_cast<Derived *>(this)->caseInvalid(t);
60 }
61 });
62 }
63};
64
65void BuildShortTypeString::appendSymName(StringRef str) {
66 if (str.empty()) {
67 ss << '?';
68 } else {
69 ss << '@' << str;
70 }
71}
72
73void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
74 appendSymName(sa.getRootReference().getValue());
75 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
76 ss << "::";
77 appendSymName(nestedRef.getValue());
78 }
79}
80
81BuildShortTypeString &BuildShortTypeString::append(Type type) {
82 size_t position = ret.size();
83 (void)position; // tell compiler it's intentionally unused in builds without assertions
84
85 struct Impl : LLZKTypeSwitch<Impl, void> {
86 BuildShortTypeString &outer;
87 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
88
89 void caseInvalid(Type) { outer.ss << "!INVALID"; }
90 void caseBool(IntegerType) { outer.ss << 'b'; }
91 void caseIndex(IndexType) { outer.ss << 'i'; }
92 void caseFelt(FeltType) { outer.ss << 'f'; }
93 void caseString(StringType) { outer.ss << 's'; }
94 void caseTypeVar(TypeVarType t) {
95 outer.ss << "!t<";
96 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
97 outer.ss << '>';
98 }
99 void caseArray(ArrayType t) {
100 outer.ss << "!a<";
101 outer.append(t.getElementType());
102 outer.ss << ':';
103 outer.append(t.getDimensionSizes());
104 outer.ss << '>';
105 }
106 void caseStruct(StructType t) {
107 outer.ss << "!s<";
108 outer.appendSymRef(t.getNameRef());
109 if (ArrayAttr params = t.getParams()) {
110 outer.ss << '_';
111 outer.append(params.getValue());
112 }
113 outer.ss << '>';
114 }
115 };
116 Impl(*this).match(type);
117
118 assert(
119 ret.find(PLACEHOLDER, position) == std::string::npos &&
120 "formatting a Type should not produce the 'PLACEHOLDER' char"
121 );
122 return *this;
123}
124
125BuildShortTypeString &BuildShortTypeString::append(Attribute a) {
126 // Special case for inserting the `PLACEHOLDER`
127 if (a == nullptr) {
128 ss << PLACEHOLDER;
129 return *this;
130 }
131
132 size_t position = ret.size();
133 (void)position; // tell compiler it's intentionally unused in builds without assertions
134
135 // Adapted from AsmPrinter::Impl::printAttributeImpl()
136 if (auto ia = llvm::dyn_cast<IntegerAttr>(a)) {
137 Type ty = ia.getType();
138 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
139 ia.getValue().print(ss, !isUnsigned);
140 } else if (auto sra = llvm::dyn_cast<SymbolRefAttr>(a)) {
141 appendSymRef(sra);
142 } else if (auto ta = llvm::dyn_cast<TypeAttr>(a)) {
143 append(ta.getValue());
144 } else if (auto ama = llvm::dyn_cast<AffineMapAttr>(a)) {
145 ss << "!m<";
146 // Filter to remove spaces from the affine_map representation
147 filtered_raw_ostream fs(ss, [](char c) { return c == ' '; });
148 ama.getValue().print(fs);
149 fs.flush();
150 ss << '>';
151 } else if (auto aa = llvm::dyn_cast<ArrayAttr>(a)) {
152 append(aa.getValue());
153 } else {
154 // All valid/legal cases must be covered above
156 }
157 assert(
158 ret.find(PLACEHOLDER, position) == std::string::npos &&
159 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
160 );
161 return *this;
162}
163
164BuildShortTypeString &BuildShortTypeString::append(ArrayRef<Attribute> attrs) {
165 llvm::interleave(attrs, ss, [this](Attribute a) { append(a); }, "_");
166 return *this;
167}
168
169std::string BuildShortTypeString::from(const std::string &base, ArrayRef<Attribute> attrs) {
170 BuildShortTypeString bldr;
171
172 bldr.ret.reserve(base.size() + attrs.size()); // reserve minimum space required
173
174 // First handle replacements of PLACEHOLDER
175 auto END = attrs.end();
176 auto IT = attrs.begin();
177 {
178 size_t start = 0;
179 for (size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
180 // Append original up to the PLACEHOLDER
181 bldr.ret.append(base, start, pos - start);
182 // Append the formatted Attribute
183 assert(IT != END && "must have an Attribute for every 'PLACEHOLDER' char");
184 bldr.append(*IT++);
185 }
186 // Append remaining suffix of the original
187 bldr.ret.append(base, start, base.size() - start);
188 }
189
190 // Append any remaining Attributes
191 if (IT != END) {
192 bldr.ss << '_';
193 bldr.append(ArrayRef(IT, END));
194 }
195
196 return bldr.ret;
197}
198
199namespace {
200
201template <typename... Types> class TypeList {
202
204 template <typename StreamType> struct Appender {
205
206 // single
207 template <typename Ty> static inline void append(StreamType &stream) {
208 stream << '\'' << Ty::name << '\'';
209 }
210
211 // multiple
212 template <typename First, typename Second, typename... Rest>
213 static void append(StreamType &stream) {
214 append<First>(stream);
215 stream << ", ";
216 append<Second, Rest...>(stream);
217 }
218
219 // full list with wrapping brackets
220 static inline void append(StreamType &stream) {
221 stream << '[';
222 append<Types...>(stream);
223 stream << ']';
224 }
225 };
226
227public:
228 // Checks if the provided value is an instance of any of `Types`
229 template <typename T> static inline bool matches(const T &value) {
230 return llvm::isa_and_present<Types...>(value);
231 }
232
233 static void reportInvalid(EmitErrorFn emitError, const Twine &foundName, const char *aspect) {
234 InFlightDiagnosticWrapper diag = emitError().append(aspect, " must be one of ");
235 Appender<InFlightDiagnosticWrapper>::append(diag);
236 diag.append(" but found '", foundName, '\'').report();
237 }
238
239 static inline void reportInvalid(EmitErrorFn emitError, Attribute found, const char *aspect) {
240 if (emitError) {
241 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() : "nullptr", aspect);
242 }
243 }
244
245 // Returns a comma-separated list formatted string of the names of `Types`
246 static inline std::string getNames() {
247 return buildStringViaCallback(Appender<llvm::raw_string_ostream>::append);
248 }
249};
250
253template <class... Ts> struct make_unique {
254 using type = TypeList<Ts...>;
255};
256
257template <class... Ts> struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
258
259template <class U, class... Us, class... Ts>
260struct make_unique<TypeList<U, Us...>, Ts...>
261 : std::conditional_t<
262 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
263 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
264
265template <class... Ts> using TypeListUnion = typename make_unique<Ts...>::type;
266
267// Dimensions in the ArrayType must be one of the following:
268// - Integer constants
269// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
270// - AffineMap (for array created within a loop where size depends on loop variable)
271using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
272
273// Parameters in the StructType must be one of the following:
274// - Integer constants
275// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
276// - Type
277// - AffineMap (for array of non-homogeneous structs)
278using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
279
280class AllowedTypes {
281 struct ColumnCheckData {
282 SymbolTableCollection *symbolTable = nullptr;
283 Operation *op = nullptr;
284 };
285
286 bool no_felt : 1 = false;
287 bool no_string : 1 = false;
288 bool no_non_signal_struct : 1 = false;
289 bool no_signal_struct : 1 = false;
290 bool no_array : 1 = false;
291 bool no_var : 1 = false;
292 bool no_int : 1 = false;
293 bool no_struct_params : 1 = false;
294 bool must_be_column : 1 = false;
295
296 ColumnCheckData columnCheck;
297
301 bool validColumns(StructType s) {
302 if (!must_be_column) {
303 return true;
304 }
305 assert(columnCheck.symbolTable);
306 assert(columnCheck.op);
307 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
308 }
309
310public:
311 constexpr AllowedTypes &noFelt() {
312 no_felt = true;
313 return *this;
314 }
315
316 constexpr AllowedTypes &noString() {
317 no_string = true;
318 return *this;
319 }
320
321 constexpr AllowedTypes &noStruct() {
322 no_non_signal_struct = true;
323 no_signal_struct = true;
324 return *this;
325 }
326
327 constexpr AllowedTypes &noStructExceptSignal() {
328 no_non_signal_struct = true;
329 no_signal_struct = false;
330 return *this;
331 }
332
333 constexpr AllowedTypes &noArray() {
334 no_array = true;
335 return *this;
336 }
337
338 constexpr AllowedTypes &noVar() {
339 no_var = true;
340 return *this;
341 }
342
343 constexpr AllowedTypes &noInt() {
344 no_int = true;
345 return *this;
346 }
347
348 constexpr AllowedTypes &noStructParams(bool noStructParams = true) {
349 no_struct_params = noStructParams;
350 return *this;
351 }
352
353 constexpr AllowedTypes &onlyInt() {
354 no_int = false;
355 return noFelt().noString().noStruct().noArray().noVar();
356 }
357
358 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
359 must_be_column = true;
360 columnCheck.symbolTable = &symbolTable;
361 columnCheck.op = op;
362 return *this;
363 }
364
365 // This is the main check for allowed types.
366 bool isValidTypeImpl(Type type);
367
368 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr) {
369 // In LLZK, the number of array dimensions must always be known, i.e., `hasRank()==true`
370 if (dimensionSizes.empty()) {
371 if (emitError) {
372 emitError().append("array must have at least one dimension").report();
373 }
374 return false;
375 }
376 // Rather than immediately returning on failure, we check all dimensions and aggregate to
377 // provide as many errors are possible in a single verifier run.
378 bool success = true;
379 for (Attribute a : dimensionSizes) {
380 if (!ArrayDimensionTypes::matches(a)) {
381 ArrayDimensionTypes::reportInvalid(emitError, a, "Array dimension");
382 success = false;
383 } else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
384 TypeList<IntegerAttr>::reportInvalid(emitError, a, "Concrete array dimension");
385 success = false;
386 } else if (failed(verifyAffineMapAttrType(emitError, a))) {
387 success = false;
388 } else if (failed(verifyIntAttrType(emitError, a))) {
389 success = false;
390 }
391 }
392 return success;
393 }
394
395 bool isValidArrayElemTypeImpl(Type type) {
396 // ArrayType element can be any valid type sans ArrayType itself.
397 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
398 }
399
400 bool isValidArrayTypeImpl(
401 Type elementType, ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr
402 ) {
403 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
404 return false;
405 }
406
407 // Ensure array element type is valid
408 if (!isValidArrayElemTypeImpl(elementType)) {
409 if (emitError) {
410 // Print proper message if `elementType` is not a valid LLZK type or
411 // if it's simply not the right kind of type for an array element.
412 if (succeeded(checkValidType(emitError, elementType))) {
413 emitError()
414 .append(
415 '\'', ArrayType::name, "' element type cannot be '",
416 elementType.getAbstractType().getName(), '\''
417 )
418 .report();
419 }
420 }
421 return false;
422 }
423 return true;
424 }
425
426 bool isValidArrayTypeImpl(Type type) {
427 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
428 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
429 }
430 return false;
431 }
432
433 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter (if any) except
434 // for the `no_struct_params` flag which requires that `params` is null or empty.
435 bool areValidStructTypeParams(ArrayAttr params, EmitErrorFn emitError = nullptr) {
436 if (isNullOrEmpty(params)) {
437 return true;
438 }
439 if (no_struct_params) {
440 return false;
441 }
442 bool success = true;
443 for (Attribute p : params) {
444 if (!StructParamTypes::matches(p)) {
445 StructParamTypes::reportInvalid(emitError, p, "Struct parameter");
446 success = false;
447 } else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
448 if (!isValidTypeImpl(tyAttr.getValue())) {
449 if (emitError) {
450 emitError().append("expected a valid LLZK type but found ", tyAttr.getValue()).report();
451 }
452 success = false;
453 }
454 } else if (no_var && !llvm::isa<IntegerAttr>(p)) {
455 TypeList<IntegerAttr>::reportInvalid(emitError, p, "Concrete struct parameter");
456 success = false;
457 } else if (failed(verifyAffineMapAttrType(emitError, p))) {
458 success = false;
459 } else if (failed(verifyIntAttrType(emitError, p))) {
460 success = false;
461 }
462 }
463
464 return success;
465 }
466};
467
468bool AllowedTypes::isValidTypeImpl(Type type) {
469 assert(
470 !(no_int && no_felt && no_string && no_var && no_non_signal_struct && no_signal_struct &&
471 no_array) &&
472 "All types have been deactivated"
473 );
474 struct Impl : LLZKTypeSwitch<Impl, bool> {
475 AllowedTypes &outer;
476 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
477
478 bool caseBool(IntegerType t) { return !outer.no_int && t.isSignlessInteger(1); }
479 bool caseIndex(IndexType) { return !outer.no_int; }
480 bool caseFelt(FeltType) { return !outer.no_felt; }
481 bool caseString(StringType) { return !outer.no_string; }
482 bool caseTypeVar(TypeVarType) { return !outer.no_var; }
483 bool caseArray(ArrayType t) {
484 return !outer.no_array &&
485 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
486 }
487 bool caseStruct(StructType t) {
488 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter.
489 if ((outer.no_signal_struct && outer.no_non_signal_struct) || !outer.validColumns(t)) {
490 return false;
491 }
492 return (!outer.no_signal_struct && isSignalType(t)) ||
493 (!outer.no_non_signal_struct && outer.areValidStructTypeParams(t.getParams()));
494 }
495 bool caseInvalid(Type _) { return false; }
496 };
497 return Impl(*this).match(type);
498}
499
500} // namespace
501
502bool isValidType(Type type) { return AllowedTypes().isValidTypeImpl(type); }
503
504bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op) {
505 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
506}
507
508bool isValidGlobalType(Type type) { return AllowedTypes().noVar().isValidTypeImpl(type); }
509
510bool isValidEmitEqType(Type type) {
511 return AllowedTypes().noString().noStructExceptSignal().isValidTypeImpl(type);
512}
513
514// Allowed types must align with StructParamTypes (defined below)
515bool isValidConstReadType(Type type) {
516 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
517}
518
519bool isValidArrayElemType(Type type) { return AllowedTypes().isValidArrayElemTypeImpl(type); }
520
521bool isValidArrayType(Type type) { return AllowedTypes().isValidArrayTypeImpl(type); }
522
523bool isConcreteType(Type type, bool allowStructParams) {
524 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
525}
526
527bool isSignalType(Type type) {
528 if (auto structParamTy = llvm::dyn_cast<StructType>(type)) {
529 return isSignalType(structParamTy);
530 }
531 return false;
532}
533
535 // Only check the leaf part of the reference (i.e., just the struct name itself) to allow cases
536 // where the `COMPONENT_NAME_SIGNAL` struct may be placed within some nesting of modules, as
537 // happens when it's imported via an IncludeOp.
538 return sType.getNameRef().getLeafReference() == COMPONENT_NAME_SIGNAL;
539}
540
541bool hasAffineMapAttr(Type type) {
542 bool encountered = false;
543 type.walk([&](AffineMapAttr a) {
544 encountered = true;
545 return WalkResult::interrupt();
546 });
547 return encountered;
548}
549
550bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); }
551
552uint64_t computeEmitEqCardinality(Type type) {
553 struct Impl : LLZKTypeSwitch<Impl, uint64_t> {
554 uint64_t caseBool(IntegerType) { return 1; }
555 uint64_t caseIndex(IndexType) { return 1; }
556 uint64_t caseFelt(FeltType) { return 1; }
557 uint64_t caseArray(ArrayType t) {
558 int64_t n = t.getNumElements();
559 assert(n >= 0);
560 return static_cast<uint64_t>(n);
561 }
562 uint64_t caseStruct(StructType t) {
563 if (isSignalType(t)) {
564 return 1;
565 }
566 llvm_unreachable("not a valid EmitEq type");
567 }
568 uint64_t caseString(StringType) { llvm_unreachable("not a valid EmitEq type"); }
569 uint64_t caseTypeVar(TypeVarType) { llvm_unreachable("tvar has unknown cardinality"); }
570 uint64_t caseInvalid(Type) { llvm_unreachable("not a valid LLZK type"); }
571 };
572 return Impl().match(type);
573}
574
575namespace {
576
585using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
586
587struct UnifierImpl {
588 ArrayRef<StringRef> rhsRevPrefix;
589 UnificationMap *unifications;
590 AffineInstantiations *affineToIntTracker;
591 // This optional function can be used to provide an exception to the standard unification
592 // rules and return a true/success result when it otherwise may not.
593 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
594
595 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
596 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
597 overrideSuccess(nullptr) {}
598
599 bool typeParamsUnify(
600 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
601 bool unifyDynamicSize = false
602 ) {
603 auto pred = [this, unifyDynamicSize](auto lhsAttr, auto rhsAttr) {
604 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
605 };
606 return (lhsParams.size() == rhsParams.size()) &&
607 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
608 }
609
610 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
611 this->affineToIntTracker = tracker;
612 return *this;
613 }
614
615 UnifierImpl &withOverrides(llvm::function_ref<bool(Type oldTy, Type newTy)> overrides) {
616 this->overrideSuccess = overrides;
617 return *this;
618 }
619
622 bool typeParamsUnify(
623 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, bool unifyDynamicSize = false
624 ) {
625 if (lhsParams && rhsParams) {
626 return typeParamsUnify(lhsParams.getValue(), rhsParams.getValue(), unifyDynamicSize);
627 }
628 // When one or the other is null, they're only equivalent if both are null
629 return !lhsParams && !rhsParams;
630 }
631
632 bool arrayTypesUnify(ArrayType lhs, ArrayType rhs) {
633 // Check if the element types of the two arrays can unify
634 if (!typesUnify(lhs.getElementType(), rhs.getElementType())) {
635 return false;
636 }
637 // Check if the dimension size attributes unify between the LHS and RHS
638 return typeParamsUnify(
639 lhs.getDimensionSizes(), rhs.getDimensionSizes(), /*unifyDynamicSize=*/true
640 );
641 }
642
643 bool structTypesUnify(StructType lhs, StructType rhs) {
644 // Check if it references the same StructDefOp, considering the additional RHS path prefix.
645 SmallVector<StringRef> rhsNames = getNames(rhs.getNameRef());
646 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
647 if (rhsNames != getNames(lhs.getNameRef())) {
648 return false;
649 }
650 // Check if the parameters unify between the LHS and RHS
651 return typeParamsUnify(lhs.getParams(), rhs.getParams(), /*unifyDynamicSize=*/false);
652 }
653
654 bool typesUnify(Type lhs, Type rhs) {
655 if (lhs == rhs) {
656 return true;
657 }
658 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
659 return true;
660 }
661 // A type variable can be any type, thus it unifies with anything.
662 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
663 track(Side::LHS, lhsTvar.getNameRef(), rhs);
664 return true;
665 }
666 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
667 track(Side::RHS, rhsTvar.getNameRef(), lhs);
668 return true;
669 }
670 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
671 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
672 }
673 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
674 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
675 }
676 return false;
677 }
678
679private:
680 template <typename Tracker, typename Key, typename Val>
681 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
682 auto key = std::make_pair(keyHead, side);
683 auto it = tracker.find(key);
684 if (it == tracker.end()) {
685 tracker.try_emplace(key, val);
686 } else if (it->getSecond() != val) {
687 it->second = nullptr;
688 }
689 }
690
691 void track(Side side, SymbolRefAttr symRef, Type ty) {
692 if (unifications) {
693 Attribute attr;
694 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
695 // If 'ty' is TypeVarType<@S>, just map to @S directly.
696 attr = tvar.getNameRef();
697 } else {
698 // Otherwise wrap as a TypeAttr.
699 attr = TypeAttr::get(ty);
700 }
701 assert(symRef);
702 assert(attr);
703 track(*unifications, side, symRef, attr);
704 }
705 }
706
707 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
708 if (unifications) {
709 // If 'attr' is TypeAttr<TypeVarType<@S>>, just map to @S directly.
710 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
711 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
712 attr = tvar.getNameRef();
713 }
714 }
715 assert(symRef);
716 assert(attr);
717 // If 'attr' is a SymbolRefAttr, map in both directions for the correctness of
718 // `isMoreConcreteUnification()` which relies on RHS check while other external
719 // checks on the UnificationMap may do LHS checks, and in the case of both being
720 // SymbolRefAttr, unification in either direction is possible.
721 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
722 track(*unifications, reverse(side), otherSymAttr, symRef);
723 }
724 track(*unifications, side, symRef, attr);
725 }
726 }
727
728 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
729 if (affineToIntTracker) {
730 assert(affineAttr);
731 assert(intAttr);
732 assert(!isDynamic(intAttr));
733 track(*affineToIntTracker, side, affineAttr, intAttr);
734 }
735 }
736
737 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr, bool unifyDynamicSize = false) {
740 // Straightforward equality check.
741 if (lhsAttr == rhsAttr) {
742 return true;
743 }
744 // AffineMapAttr can unify with IntegerAttr (other than kDynamic) because struct parameter
745 // instantiation will result in conversion of AffineMapAttr to IntegerAttr.
746 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
747 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
748 if (!isDynamic(rhsInt)) {
749 track(Side::LHS, lhsAffine, rhsInt);
750 return true;
751 }
752 }
753 }
754 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
755 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
756 if (!isDynamic(lhsInt)) {
757 track(Side::RHS, rhsAffine, lhsInt);
758 return true;
759 }
760 }
761 }
762 // If either side is a SymbolRefAttr, assume they unify because either flattening or a pass with
763 // a more involved value analysis is required to check if they are actually the same value.
764 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
765 track(Side::LHS, lhsSymRef, rhsAttr);
766 return true;
767 }
768 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
769 track(Side::RHS, rhsSymRef, lhsAttr);
770 return true;
771 }
772 // If either side is ShapedType::kDynamic then, similarly to Symbols, assume they unify.
773 // NOTE: Dynamic array dimensions (i.e. '?') are allowed in LLZK but should generally be
774 // restricted to scenarios where it can be replaced with a concrete value during the flattening
775 // pass, such as a `unifiable_cast` where the other side of the cast has concrete dimensions or
776 // extern functions with varargs.
777 if (unifyDynamicSize) {
778 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
779 if (IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
780 if (isDynamic(intAttr)) {
781 return intAttr;
782 }
783 }
784 return nullptr;
785 };
786 auto is_const_like = [](Attribute attr) {
787 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
788 };
789 if (IntegerAttr lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
790 if (is_const_like(rhsAttr)) {
791 return true;
792 }
793 }
794 if (IntegerAttr rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
795 if (is_const_like(lhsAttr)) {
796 return true;
797 }
798 }
799 }
800 // If both are type refs, check for unification of the types.
801 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
802 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
803 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
804 }
805 }
806 // Otherwise, they do not unify.
807 return false;
808 }
809};
810
811} // namespace
812
814 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
815 UnificationMap *unifications
816) {
817 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
818}
819
823 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, UnificationMap *unifications
824) {
825 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
826}
827
829 ArrayType lhs, ArrayType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
830) {
831 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
832}
833
835 StructType lhs, StructType rhs, ArrayRef<StringRef> rhsReversePrefix,
836 UnificationMap *unifications
837) {
838 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
839}
840
842 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
843) {
844 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
845}
846
848 Type oldTy, Type newTy, llvm::function_ref<bool(Type oldTy, Type newTy)> knownOldToNew
849) {
850 UnificationMap unifications;
851 AffineInstantiations affineInstantiations;
852 // Run type unification with the addition that affine map can become integer in the new type.
853 if (!UnifierImpl(&unifications)
854 .trackAffineToInt(&affineInstantiations)
855 .withOverrides(knownOldToNew)
856 .typesUnify(oldTy, newTy)) {
857 return false;
858 }
859
860 // If either map contains RHS-keyed mappings then the old type is "more concrete" than the new.
861 // In the UnificationMap, a RHS key would indicate that the new type contains a SymbolRef (i.e.
862 // the "least concrete" attribute kind) where the old type contained any other attribute. In the
863 // AffineInstantiations map, a RHS key would indicate that the new type contains an AffineMapAttr
864 // where the old type contains an IntegerAttr.
865 auto entryIsRHS = [](const auto &entry) { return entry.first.second == Side::RHS; };
866 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
867}
868
869FailureOr<IntegerAttr> forceIntType(IntegerAttr attr, EmitErrorFn emitError) {
870 if (llvm::isa<IndexType>(attr.getType())) {
871 return attr;
872 }
873 // Ensure the APInt is the right bitwidth for IndexType or else
874 // IntegerAttr::verify(..) will report an error.
875 APInt value = attr.getValue();
876 auto compare = value.getBitWidth() <=> IndexType::kInternalStorageBitWidth;
877 if (compare < 0) {
878 value = value.zext(IndexType::kInternalStorageBitWidth);
879 } else if (compare > 0) {
880 return emitError().append("value is too large for `index` type: ", debug::toStringOne(value));
881 }
882 return IntegerAttr::get(IndexType::get(attr.getContext()), value);
883}
884
885FailureOr<Attribute> forceIntAttrType(Attribute attr, EmitErrorFn emitError) {
886 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
887 return forceIntType(intAttr, emitError);
888 }
889 return attr;
890}
891
892FailureOr<SmallVector<Attribute>>
893forceIntAttrTypes(ArrayRef<Attribute> attrList, EmitErrorFn emitError) {
894 SmallVector<Attribute> result;
895 for (Attribute attr : attrList) {
896 FailureOr<Attribute> forced = forceIntAttrType(attr, emitError);
897 if (failed(forced)) {
898 return failure();
899 }
900 result.push_back(*forced);
901 }
902 return result;
903}
904
905LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in) {
906 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
907 Type attrTy = intAttr.getType();
908 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
909 if (emitError) {
910 emitError()
911 .append("IntegerAttr must have type 'index' or 'i1' but found '", attrTy, '\'')
912 .report();
913 }
914 return failure();
915 }
916 }
917 return success();
918}
919
920LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in) {
921 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
922 AffineMap map = affineAttr.getValue();
923 if (map.getNumResults() != 1) {
924 if (emitError) {
925 emitError()
926 .append(
927 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
928 " results"
929 )
930 .report();
931 }
932 return failure();
933 }
934 }
935 return success();
936}
937
938LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params) {
939 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
940}
941
942LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes) {
943 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
944}
945
946LogicalResult
947verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes) {
948 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
949}
950
951void assertValidAttrForParamOfType(Attribute attr) {
952 // Must be the union of valid attribute types within ArrayType, StructType, and TypeVarType.
953 using TypeVarAttrs = TypeList<SymbolRefAttr>; // per ODS spec of TypeVarType
954 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
955 llvm::report_fatal_error(
956 "Legal type parameters are inconsistent. Encountered " +
957 attr.getAbstractAttribute().getName()
958 );
959 }
960}
961
962LogicalResult
963verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType) {
964 ArrayRef<Attribute> dimsFromArr = arrayType.getDimensionSizes();
965 size_t numArrDims = dimsFromArr.size();
966 ArrayRef<Attribute> dimsFromSubArr = subArrayType.getDimensionSizes();
967 size_t numSubArrDims = dimsFromSubArr.size();
968
969 if (numArrDims < numSubArrDims) {
970 return emitError().append(
971 "subarray type ", subArrayType, " has more dimensions than array type ", arrayType
972 );
973 }
974
975 size_t toDrop = numArrDims - numSubArrDims;
976 ArrayRef<Attribute> dimsFromArrReduced = dimsFromArr.drop_front(toDrop);
977
978 // Ensure dimension sizes are compatible (ignoring the indexed dimensions)
979 if (!typeParamsUnify(dimsFromArrReduced, dimsFromSubArr)) {
980 std::string message;
981 llvm::raw_string_ostream ss(message);
982 auto appendOne = [&ss](Attribute a) { appendWithoutType(ss, a); };
983 ss << "cannot unify array dimensions [";
984 llvm::interleaveComma(dimsFromArrReduced, ss, appendOne);
985 ss << "] with [";
986 llvm::interleaveComma(dimsFromSubArr, ss, appendOne);
987 ss << "]";
988 return emitError().append(message);
989 }
990
991 // Ensure element types of the arrays are compatible
992 if (!typesUnify(arrayType.getElementType(), subArrayType.getElementType())) {
993 return emitError().append(
994 "incorrect array element type; expected: ", arrayType.getElementType(),
995 ", found: ", subArrayType.getElementType()
996 );
997 }
998
999 return success();
1000}
1001
1002LogicalResult
1003verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType) {
1004 if (auto subArrayType = llvm::dyn_cast<ArrayType>(subArrayOrElemType)) {
1005 return verifySubArrayType(emitError, arrayType, subArrayType);
1006 }
1007 if (!typesUnify(arrayType.getElementType(), subArrayOrElemType)) {
1008 return emitError().append(
1009 "incorrect array element type; expected: ", arrayType.getElementType(),
1010 ", found: ", subArrayOrElemType
1011 );
1012 }
1013
1014 return success();
1015}
1016
1017} // namespace llzk
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
Definition TypeHelper.h:36
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::mlir::SymbolRefAttr getNameRef() const
std::string toStringOne(const T &value)
Definition Debug.h:175
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:185
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:143
bool isSignalType(Type type)
int64_t fromAPInt(const llvm::APInt &i)
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
Definition Constants.h:16
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
void appendWithoutType(mlir::raw_ostream &os, mlir::Attribute a)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)