LLZK 0.1.0
Veridise's ZK Language IR
Loading...
Searching...
No Matches
TypeHelper.cpp
Go to the documentation of this file.
1//===-- TypeHelper.cpp ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
19
20using namespace mlir;
21
22namespace llzk {
23
24using namespace array;
25using namespace component;
26using namespace felt;
27using namespace polymorphic;
28using namespace string;
29
30void BuildShortTypeString::appendSymName(StringRef str) {
31 if (str.empty()) {
32 ss << '?';
33 } else {
34 ss << '@' << str;
35 }
36}
37
38void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
39 appendSymName(sa.getRootReference().getValue());
40 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
41 ss << "::";
42 appendSymName(nestedRef.getValue());
43 }
44}
45
46BuildShortTypeString &BuildShortTypeString::append(Type type) {
47 size_t position = ret.size();
48 // Cases must be consistent with isValidTypeImpl() below.
49 if (type.isSignlessInteger(1)) {
50 ss << 'b';
51 } else if (llvm::isa<IndexType>(type)) {
52 ss << 'i';
53 } else if (llvm::isa<FeltType>(type)) {
54 ss << 'f';
55 } else if (llvm::isa<StringType>(type)) {
56 ss << 's';
57 } else if (llvm::isa<TypeVarType>(type)) {
58 ss << "!t<";
59 appendSymName(llvm::cast<TypeVarType>(type).getRefName());
60 ss << '>';
61 } else if (llvm::isa<ArrayType>(type)) {
62 ArrayType at = llvm::cast<ArrayType>(type);
63 ss << "!a<";
64 append(at.getElementType());
65 ss << ':';
66 append(at.getDimensionSizes());
67 ss << '>';
68 } else if (llvm::isa<StructType>(type)) {
69 StructType st = llvm::cast<StructType>(type);
70 ss << "!s<";
71 appendSymRef(st.getNameRef());
72 if (ArrayAttr params = st.getParams()) {
73 ss << '_';
74 append(params.getValue());
75 }
76 ss << '>';
77 } else {
78 ss << "!INVALID";
79 }
80 assert(
81 ret.find(PLACEHOLDER, position) == std::string::npos &&
82 "formatting a Type should not produce the 'PLACEHOLDER' char"
83 );
84 return *this;
85}
86
87BuildShortTypeString &BuildShortTypeString::append(Attribute a) {
88 // Special case for inserting the `PLACEHOLDER`
89 if (a == nullptr) {
90 ss << PLACEHOLDER;
91 return *this;
92 }
93
94 size_t position = ret.size();
95 // Adapted from AsmPrinter::Impl::printAttributeImpl()
96 if (llvm::isa<IntegerAttr>(a)) {
97 IntegerAttr ia = llvm::cast<IntegerAttr>(a);
98 Type ty = ia.getType();
99 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
100 ia.getValue().print(ss, !isUnsigned);
101 } else if (llvm::isa<SymbolRefAttr>(a)) {
102 appendSymRef(llvm::cast<SymbolRefAttr>(a));
103 } else if (llvm::isa<TypeAttr>(a)) {
104 append(llvm::cast<TypeAttr>(a).getValue());
105 } else if (llvm::isa<AffineMapAttr>(a)) {
106 ss << "!m<";
107 // Filter to remove spaces from the affine_map representation
108 filtered_raw_ostream fs(ss, [](char c) { return c == ' '; });
109 llvm::cast<AffineMapAttr>(a).getValue().print(fs);
110 fs.flush();
111 ss << '>';
112 } else if (llvm::isa<ArrayAttr>(a)) {
113 append(llvm::cast<ArrayAttr>(a).getValue());
114 } else {
115 // All valid/legal cases must be covered above
117 }
118 assert(
119 ret.find(PLACEHOLDER, position) == std::string::npos &&
120 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
121 );
122 return *this;
123}
124
125BuildShortTypeString &BuildShortTypeString::append(ArrayRef<Attribute> attrs) {
126 llvm::interleave(attrs, ss, [this](Attribute a) { append(a); }, "_");
127 return *this;
128}
129
130std::string BuildShortTypeString::from(const std::string &base, ArrayRef<Attribute> attrs) {
131 BuildShortTypeString bldr;
132
133 bldr.ret.reserve(base.size() + attrs.size()); // reserve minimum space required
134
135 // First handle replacements of PLACEHOLDER
136 auto END = attrs.end();
137 auto IT = attrs.begin();
138 {
139 size_t start = 0;
140 for (size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
141 // Append original up to the PLACEHOLDER
142 bldr.ret.append(base, start, pos - start);
143 // Append the formatted Attribute
144 assert(IT != END && "must have an Attribute for every 'PLACEHOLDER' char");
145 bldr.append(*IT++);
146 }
147 // Append remaining suffix of the original
148 bldr.ret.append(base, start, base.size() - start);
149 }
150
151 // Append any remaining Attributes
152 if (IT != END) {
153 bldr.ss << '_';
154 bldr.append(ArrayRef(IT, END));
155 }
156
157 return bldr.ret;
158}
159
160namespace {
161
162template <typename... Types> class TypeList {
163
165 template <typename StreamType> struct Appender {
166
167 // single
168 template <typename Ty> static inline void append(StreamType &stream) {
169 stream << '\'' << Ty::name << '\'';
170 }
171
172 // multiple
173 template <typename First, typename Second, typename... Rest>
174 static void append(StreamType &stream) {
175 append<First>(stream);
176 stream << ", ";
177 append<Second, Rest...>(stream);
178 }
179
180 // full list with wrapping brackets
181 static inline void append(StreamType &stream) {
182 stream << '[';
183 append<Types...>(stream);
184 stream << ']';
185 }
186 };
187
188public:
189 // Checks if the provided value is an instance of any of `Types`
190 template <typename T> static inline bool matches(const T &value) {
191 return llvm::isa_and_present<Types...>(value);
192 }
193
194 static void reportInvalid(EmitErrorFn emitError, const Twine &foundName, const char *aspect) {
195 InFlightDiagnostic diag = emitError().append(aspect, " must be one of ");
196 Appender<InFlightDiagnostic>::append(diag);
197 diag.append(" but found '", foundName, "'").report();
198 }
199
200 static inline void reportInvalid(EmitErrorFn emitError, Attribute found, const char *aspect) {
201 if (emitError) {
202 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() : "nullptr", aspect);
203 }
204 }
205
206 // Returns a comma-separated list formatted string of the names of `Types`
207 static inline std::string getNames() {
208 return buildStringViaCallback(Appender<llvm::raw_string_ostream>::append);
209 }
210};
211
214template <class... Ts> struct make_unique {
215 using type = TypeList<Ts...>;
216};
217
218template <class... Ts> struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
219
220template <class U, class... Us, class... Ts>
221struct make_unique<TypeList<U, Us...>, Ts...>
222 : std::conditional_t<
223 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
224 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
225
226template <class... Ts> using TypeListUnion = typename make_unique<Ts...>::type;
227
228// Dimensions in the ArrayType must be one of the following:
229// - Integer constants
230// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
231// - AffineMap (for array created within a loop where size depends on loop variable)
232using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
233
234// Parameters in the StructType must be one of the following:
235// - Integer constants
236// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
237// - Type
238// - AffineMap (for array of non-homogeneous structs)
239using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
240
241class AllowedTypes {
242 struct ColumnCheckData {
243 SymbolTableCollection *symbolTable = nullptr;
244 Operation *op = nullptr;
245 };
246
247 bool no_felt : 1 = false;
248 bool no_string : 1 = false;
249 bool no_non_signal_struct : 1 = false;
250 bool no_signal_struct : 1 = false;
251 bool no_array : 1 = false;
252 bool no_var : 1 = false;
253 bool no_int : 1 = false;
254 bool no_struct_params : 1 = false;
255 bool must_be_column : 1 = false;
256
257 ColumnCheckData columnCheck;
258
262 bool validColumns(StructType s) {
263 if (!must_be_column) {
264 return true;
265 }
266 assert(columnCheck.symbolTable);
267 assert(columnCheck.op);
268 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
269 }
270
271public:
272 constexpr AllowedTypes &noFelt() {
273 no_felt = true;
274 return *this;
275 }
276
277 constexpr AllowedTypes &noString() {
278 no_string = true;
279 return *this;
280 }
281
282 constexpr AllowedTypes &noStruct() {
283 no_non_signal_struct = true;
284 no_signal_struct = true;
285 return *this;
286 }
287
288 constexpr AllowedTypes &noStructExceptSignal() {
289 no_non_signal_struct = true;
290 no_signal_struct = false;
291 return *this;
292 }
293
294 constexpr AllowedTypes &noArray() {
295 no_array = true;
296 return *this;
297 }
298
299 constexpr AllowedTypes &noVar() {
300 no_var = true;
301 return *this;
302 }
303
304 constexpr AllowedTypes &noInt() {
305 no_int = true;
306 return *this;
307 }
308
309 constexpr AllowedTypes &noStructParams(bool noStructParams = true) {
310 no_struct_params = noStructParams;
311 return *this;
312 }
313
314 constexpr AllowedTypes &onlyInt() {
315 no_int = false;
316 return noFelt().noString().noStruct().noArray().noVar();
317 }
318
319 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
320 must_be_column = true;
321 columnCheck.symbolTable = &symbolTable;
322 columnCheck.op = op;
323 return *this;
324 }
325
326 // This is the main check for allowed types.
327 bool isValidTypeImpl(Type type);
328
329 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr) {
330 // In LLZK, the number of array dimensions must always be known, i.e. `hasRank()==true`
331 if (dimensionSizes.empty()) {
332 if (emitError) {
333 emitError().append("array must have at least one dimension").report();
334 }
335 return false;
336 }
337 // Rather than immediately returning on failure, we check all dimensions and aggregate to
338 // provide as many errors are possible in a single verifier run.
339 bool success = true;
340 for (Attribute a : dimensionSizes) {
341 if (!ArrayDimensionTypes::matches(a)) {
342 ArrayDimensionTypes::reportInvalid(emitError, a, "Array dimension");
343 success = false;
344 } else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
345 TypeList<IntegerAttr>::reportInvalid(emitError, a, "Concrete array dimension");
346 success = false;
347 } else if (failed(verifyAffineMapAttrType(emitError, a))) {
348 success = false;
349 } else if (failed(verifyIntAttrType(emitError, a))) {
350 success = false;
351 }
352 }
353 return success;
354 }
355
356 bool isValidArrayElemTypeImpl(Type type) {
357 // ArrayType element can be any valid type sans ArrayType itself.
358 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
359 }
360
361 bool isValidArrayTypeImpl(
362 Type elementType, ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr
363 ) {
364 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
365 return false;
366 }
367
368 // Ensure array element type is valid
369 if (!isValidArrayElemTypeImpl(elementType)) {
370 if (emitError) {
371 // Print proper message if `elementType` is not a valid LLZK type or
372 // if it's simply not the right kind of type for an array element.
373 if (succeeded(checkValidType(emitError, elementType))) {
374 emitError()
375 .append(
376 "'", ArrayType::name, "' element type cannot be '",
377 elementType.getAbstractType().getName(), "'"
378 )
379 .report();
380 }
381 }
382 return false;
383 }
384 return true;
385 }
386
387 bool isValidArrayTypeImpl(Type type) {
388 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
389 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
390 }
391 return false;
392 }
393
394 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter (if any) except
395 // for the `no_struct_params` flag which requires that `params` is null or empty.
396 bool areValidStructTypeParams(ArrayAttr params, EmitErrorFn emitError = nullptr) {
397 if (isNullOrEmpty(params)) {
398 return true;
399 }
400 if (no_struct_params) {
401 return false;
402 }
403 bool success = true;
404 for (Attribute p : params) {
405 if (!StructParamTypes::matches(p)) {
406 StructParamTypes::reportInvalid(emitError, p, "Struct parameter");
407 success = false;
408 } else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
409 if (!isValidTypeImpl(tyAttr.getValue())) {
410 if (emitError) {
411 emitError().append("expected a valid LLZK type but found ", tyAttr.getValue()).report();
412 }
413 success = false;
414 }
415 } else if (no_var && !llvm::isa<IntegerAttr>(p)) {
416 TypeList<IntegerAttr>::reportInvalid(emitError, p, "Concrete struct parameter");
417 success = false;
418 } else if (failed(verifyAffineMapAttrType(emitError, p))) {
419 success = false;
420 } else if (failed(verifyIntAttrType(emitError, p))) {
421 success = false;
422 }
423 }
424
425 return success;
426 }
427
428 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter.
429 bool isValidStructTypeImpl(Type type, bool allowSignalStruct, bool allowNonSignalStruct) {
430 if (!allowSignalStruct && !allowNonSignalStruct) {
431 return false;
432 }
433 if (StructType sType = llvm::dyn_cast<StructType>(type); sType && validColumns(sType)) {
434 return (allowSignalStruct && isSignalType(sType)) ||
435 (allowNonSignalStruct && areValidStructTypeParams(sType.getParams()));
436 }
437 return false;
438 }
439};
440
441bool AllowedTypes::isValidTypeImpl(Type type) {
442 assert(
443 !(no_int && no_felt && no_string && no_var && no_non_signal_struct && no_signal_struct &&
444 no_array) &&
445 "All types have been deactivated"
446 );
447 return (!no_int && type.isSignlessInteger(1)) || (!no_int && llvm::isa<IndexType>(type)) ||
448 (!no_felt && llvm::isa<FeltType>(type)) || (!no_string && llvm::isa<StringType>(type)) ||
449 (!no_var && llvm::isa<TypeVarType>(type)) || (!no_array && isValidArrayTypeImpl(type)) ||
450 isValidStructTypeImpl(type, !no_signal_struct, !no_non_signal_struct);
451}
452
453} // namespace
454
455bool isValidType(Type type) { return AllowedTypes().isValidTypeImpl(type); }
456
457bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op) {
458 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
459}
460
461bool isValidGlobalType(Type type) { return AllowedTypes().noVar().isValidTypeImpl(type); }
462
463bool isValidEmitEqType(Type type) {
464 return AllowedTypes().noString().noStructExceptSignal().isValidTypeImpl(type);
465}
466
467// Allowed types must align with StructParamTypes (defined below)
468bool isValidConstReadType(Type type) {
469 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
470}
471
472bool isValidArrayElemType(Type type) { return AllowedTypes().isValidArrayElemTypeImpl(type); }
473
474bool isValidArrayType(Type type) { return AllowedTypes().isValidArrayTypeImpl(type); }
475
476bool isConcreteType(Type type, bool allowStructParams) {
477 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
478}
479
480bool isSignalType(Type type) {
481 if (auto structParamTy = llvm::dyn_cast<StructType>(type)) {
482 return isSignalType(structParamTy);
483 }
484 return false;
485}
486
488 // Only check the leaf part of the reference (i.e. just the struct name itself) to allow cases
489 // where the `COMPONENT_NAME_SIGNAL` struct may be placed within some nesting of modules, as
490 // happens when it's imported via an IncludeOp.
491 return sType.getNameRef().getLeafReference() == COMPONENT_NAME_SIGNAL;
492}
493
494bool hasAffineMapAttr(Type type) {
495 bool encountered = false;
496 type.walk([&](AffineMapAttr a) {
497 encountered = true;
498 return WalkResult::interrupt();
499 });
500 return encountered;
501}
502
503bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); }
504
505namespace {
506
// Records, per (AffineMapAttr, Side) key, the IntegerAttr the AffineMapAttr was unified with
// (filled by UnifierImpl::track). If a key is later unified with a different IntegerAttr, the
// stored value is replaced with a null IntegerAttr to mark the conflict.
515using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
516
// Implements unification of LLZK types and type parameters. It is shared by the public
// `*Unify()` functions and `isMoreConcreteUnification()`. Bindings discovered along the way are
// recorded only into the trackers that are non-null.
517struct UnifierImpl {
 // Path components that are reversed and prepended to the RHS struct name before comparing it
 // with the LHS struct name (see structTypesUnify()). Non-owning view.
518 ArrayRef<StringRef> rhsRevPrefix;
 // Optional output (may be null): (SymbolRefAttr, Side) -> Attribute it was unified with.
519 UnificationMap *unifications;
 // Optional output (may be null): AffineMapAttr -> IntegerAttr instantiations.
520 AffineInstantiations *affineToIntTracker;
521 // This optional function can be used to provide an exception to the standard unification
522 // rules and return a true/success result when it otherwise may not.
523 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
524
 // `rhsReversePrefix` is stored as a non-owning ArrayRef, so it must outlive this object.
525 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
526 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
527 overrideSuccess(nullptr) {}
528
 // Returns true if both lists have the same length and every pair of parameters unifies.
 // std::equal stops at the first pair that does not unify, so later pairs are not tracked.
529 bool typeParamsUnify(
530 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
531 bool unifyDynamicSize = false
532 ) {
533 auto pred = [this, unifyDynamicSize](auto lhsAttr, auto rhsAttr) {
534 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
535 };
536 return (lhsParams.size() == rhsParams.size()) &&
537 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
538 }
539
 // Builder: also record AffineMapAttr -> IntegerAttr instantiations into `tracker`.
540 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
541 this->affineToIntTracker = tracker;
542 return *this;
543 }
544
 // Builder: install the `overrideSuccess` callback consulted by typesUnify().
545 UnifierImpl &withOverrides(llvm::function_ref<bool(Type oldTy, Type newTy)> overrides) {
546 this->overrideSuccess = overrides;
547 return *this;
548 }
549
 // ArrayAttr overload: two null ArrayAttrs unify; exactly one null does not.
552 bool typeParamsUnify(
553 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, bool unifyDynamicSize = false
554 ) {
555 if (lhsParams && rhsParams) {
556 return typeParamsUnify(lhsParams.getValue(), rhsParams.getValue(), unifyDynamicSize);
557 }
558 // When one or the other is null, they're only equivalent if both are null
559 return !lhsParams && !rhsParams;
560 }
561
 // Arrays unify when their element types unify and their dimension sizes unify. Dimension
 // sizes are compared with `unifyDynamicSize=true`.
562 bool arrayTypesUnify(ArrayType lhs, ArrayType rhs) {
563 // Check if the element types of the two arrays can unify
564 if (!typesUnify(lhs.getElementType(), rhs.getElementType())) {
565 return false;
566 }
567 // Check if the dimension size attributes unify between the LHS and RHS
568 return typeParamsUnify(
569 lhs.getDimensionSizes(), rhs.getDimensionSizes(), /*unifyDynamicSize=*/true
570 );
571 }
572
 // Structs unify when the (prefixed) RHS name path equals the LHS name path and their
 // parameter lists unify.
573 bool structTypesUnify(StructType lhs, StructType rhs) {
574 // Check if it references the same StructDefOp, considering the additional RHS path prefix.
575 SmallVector<StringRef> rhsNames = getNames(rhs.getNameRef());
576 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
577 if (rhsNames != getNames(lhs.getNameRef())) {
578 return false;
579 }
580 // Check if the parameters unify between the LHS and RHS
581 return typeParamsUnify(lhs.getParams(), rhs.getParams());
582 }
583
 // Top-level type unification. Checks, in order: identical types, the override callback, a
 // type variable on either side (recorded via track()), then struct/struct and array/array.
 // Any other combination does not unify.
584 bool typesUnify(Type lhs, Type rhs) {
585 if (lhs == rhs) {
586 return true;
587 }
588 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
589 return true;
590 }
591 // A type variable can be any type, thus it unifies with anything.
592 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
593 track(Side::LHS, lhsTvar.getNameRef(), rhs);
594 return true;
595 }
596 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
597 track(Side::RHS, rhsTvar.getNameRef(), lhs);
598 return true;
599 }
600 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
601 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
602 }
603 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
604 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
605 }
606 return false;
607 }
608
609private:
 // Generic recorder: the first value seen for (keyHead, side) is stored. If a later value
 // differs, the entry is set to null to mark conflicting bindings.
610 template <typename Tracker, typename Key, typename Val>
611 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
612 auto key = std::make_pair(keyHead, side);
613 auto it = tracker.find(key);
614 if (it == tracker.end()) {
615 tracker.try_emplace(key, val);
616 } else if (it->getSecond() != val) {
617 it->second = nullptr;
618 }
619 }
620
 // Records that type variable `symRef` (on `side`) unified with `ty`. No-op when
 // `unifications` is null.
621 void track(Side side, SymbolRefAttr symRef, Type ty) {
622 if (unifications) {
623 Attribute attr;
624 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
625 // If 'ty' is TypeVarType<@S>, just map to @S directly.
626 attr = tvar.getNameRef();
627 } else {
628 // Otherwise wrap as a TypeAttr.
629 attr = TypeAttr::get(ty);
630 }
631 assert(symRef);
632 assert(attr);
633 track(*unifications, side, symRef, attr);
634 }
635 }
636
 // Records that symbol `symRef` (on `side`) unified with parameter `attr`. No-op when
 // `unifications` is null.
637 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
638 if (unifications) {
639 // If 'attr' is TypeAttr<TypeVarType<@S>>, just map to @S directly.
640 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
641 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
642 attr = tvar.getNameRef();
643 }
644 }
645 assert(symRef);
646 assert(attr);
647 // If 'attr' is a SymbolRefAttr, map in both directions for the correctness of
648 // `isMoreConcreteUnification()` which relies on RHS check while other external
649 // checks on the UnificationMap may do LHS checks, and in the case of both being
650 // SymbolRefAttr, unification in either direction is possible.
651 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
652 track(*unifications, reverse(side), otherSymAttr, symRef);
653 }
654 track(*unifications, side, symRef, attr);
655 }
656 }
657
 // Records that `affineAttr` (on `side`) was instantiated by the non-dynamic `intAttr`. No-op
 // when `affineToIntTracker` is null.
658 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
659 if (affineToIntTracker) {
660 assert(affineAttr);
661 assert(intAttr);
662 assert(!isDynamic(intAttr));
663 track(*affineToIntTracker, side, affineAttr, intAttr);
664 }
665 }
666
 // Returns true if two type-parameter Attributes can unify. The checks, in order:
 //  - the Attributes are equal;
 //  - AffineMapAttr vs non-dynamic IntegerAttr (recorded via track());
 //  - either side is a SymbolRefAttr (recorded via track());
 //  - a dynamic-size IntegerAttr vs an IntegerAttr/SymbolRefAttr/AffineMapAttr;
 //  - two TypeAttrs whose types unify.
 // NOTE(review): `unifyDynamicSize` is not referenced in the visible body, and this listing skips
 // source lines 668-669 right after the signature; confirm both against the original file.
667 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr, bool unifyDynamicSize = false) {
670 // Straightforward equality check.
671 if (lhsAttr == rhsAttr) {
672 return true;
673 }
674 // AffineMapAttr can unify with IntegerAttr (other than kDynamic) because struct parameter
675 // instantiation will result in conversion of AffineMapAttr to IntegerAttr.
676 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
677 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
678 if (!isDynamic(rhsInt)) {
679 track(Side::LHS, lhsAffine, rhsInt);
680 return true;
681 }
682 }
683 }
684 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
685 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
686 if (!isDynamic(lhsInt)) {
687 track(Side::RHS, rhsAffine, lhsInt);
688 return true;
689 }
690 }
691 }
692 // If either side is a SymbolRefAttr, assume they unify because either flattening or a pass with
693 // a more involved value analysis is required to check if they are actually the same value.
694 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
695 track(Side::LHS, lhsSymRef, rhsAttr);
696 return true;
697 }
698 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
699 track(Side::RHS, rhsSymRef, lhsAttr);
700 return true;
701 }
702 // If either side is ShapedType::kDynamic then, similarly to Symbols, assume they unify.
703 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
704 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
705 if (isDynamic(intAttr)) {
706 return intAttr;
707 }
708 }
709 return nullptr;
710 };
711 auto isa_const = [](Attribute attr) {
712 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
713 };
714 if (auto lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
715 if (isa_const(rhsAttr)) {
716 return true;
717 }
718 }
719 if (auto rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
720 if (isa_const(lhsAttr)) {
721 return true;
722 }
723 }
724 // If both are type refs, check for unification of the types.
725 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
726 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
727 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
728 }
729 }
730 // Otherwise, they do not unify.
731 return false;
732 }
733};
734
735} // namespace
736
738 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
739 UnificationMap *unifications
740) {
741 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
742}
743
747 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, UnificationMap *unifications
748) {
749 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
750}
751
753 ArrayType lhs, ArrayType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
754) {
755 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
756}
757
759 StructType lhs, StructType rhs, ArrayRef<StringRef> rhsReversePrefix,
760 UnificationMap *unifications
761) {
762 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
763}
764
766 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
767) {
768 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
769}
770
772 Type oldTy, Type newTy, llvm::function_ref<bool(Type oldTy, Type newTy)> knownOldToNew
773) {
774 UnificationMap unifications;
775 AffineInstantiations affineInstantiations;
776 // Run type unification with the addition that affine map can become integer in the new type.
777 if (!UnifierImpl(&unifications)
778 .trackAffineToInt(&affineInstantiations)
779 .withOverrides(knownOldToNew)
780 .typesUnify(oldTy, newTy)) {
781 return false;
782 }
783
784 // If either map contains RHS-keyed mappings then the old type is "more concrete" than the new.
785 // In the UnificationMap, a RHS key would indicate that the new type contains a SymbolRef (i.e.
786 // the "least concrete" attribute kind) where the old type contained any other attribute. In the
787 // AffineInstantiations map, a RHS key would indicate that the new type contains an AffineMapAttr
788 // where the old type contains an IntegerAttr.
789 auto entryIsRHS = [](const auto &entry) { return entry.first.second == Side::RHS; };
790 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
791}
792
793IntegerAttr forceIntType(IntegerAttr attr) {
794 if (AllowedTypes().onlyInt().isValidTypeImpl(attr.getType())) {
795 return attr;
796 }
797 return IntegerAttr::get(IndexType::get(attr.getContext()), attr.getValue());
798}
799
800Attribute forceIntAttrType(Attribute attr) {
801 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
802 return forceIntType(intAttr);
803 }
804 return attr;
805}
806
807SmallVector<Attribute> forceIntAttrTypes(ArrayRef<Attribute> attrList) {
808 return llvm::map_to_vector(attrList, forceIntAttrType);
809}
810
811LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in) {
812 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
813 Type attrTy = intAttr.getType();
814 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
815 if (emitError) {
816 emitError()
817 .append("IntegerAttr must have type 'index' or 'i1' but found '", attrTy, "'")
818 .report();
819 }
820 return failure();
821 }
822 }
823 return success();
824}
825
826LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in) {
827 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
828 AffineMap map = affineAttr.getValue();
829 if (map.getNumResults() != 1) {
830 if (emitError) {
831 emitError()
832 .append(
833 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
834 " results"
835 )
836 .report();
837 }
838 return failure();
839 }
840 }
841 return success();
842}
843
844LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params) {
845 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
846}
847
848LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes) {
849 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
850}
851
852LogicalResult
853verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes) {
854 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
855}
856
857void assertValidAttrForParamOfType(Attribute attr) {
858 // Must be the union of valid attribute types within ArrayType, StructType, and TypeVarType.
859 using TypeVarAttrs = TypeList<SymbolRefAttr>; // per ODS spec of TypeVarType
860 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
861 llvm::report_fatal_error(
862 "Legal type parameters are inconsistent. Encountered " +
863 attr.getAbstractAttribute().getName()
864 );
865 }
866}
867
868} // namespace llzk
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
Definition TypeHelper.h:36
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::mlir::SymbolRefAttr getNameRef() const
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
llvm::function_ref< mlir::InFlightDiagnostic()> EmitErrorFn
Definition ErrorHelper.h:18
bool isValidGlobalType(Type type)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
SmallVector< Attribute > forceIntAttrTypes(ArrayRef< Attribute > attrList)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:181
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:143
bool isSignalType(Type type)
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
Definition Constants.h:16
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
Attribute forceIntAttrType(Attribute attr)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
IntegerAttr forceIntType(IntegerAttr attr)
bool hasAffineMapAttr(Type type)
int64_t fromAPInt(llvm::APInt i)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:107
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)