30void BuildShortTypeString::appendSymName(StringRef str) {
38void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
39 appendSymName(sa.getRootReference().getValue());
40 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
42 appendSymName(nestedRef.getValue());
47 size_t position = ret.size();
49 if (type.isSignlessInteger(1)) {
51 }
else if (llvm::isa<IndexType>(type)) {
53 }
else if (llvm::isa<FeltType>(type)) {
55 }
else if (llvm::isa<StringType>(type)) {
57 }
else if (llvm::isa<TypeVarType>(type)) {
59 appendSymName(llvm::cast<TypeVarType>(type).getRefName());
61 }
else if (llvm::isa<ArrayType>(type)) {
62 ArrayType at = llvm::cast<ArrayType>(type);
64 append(at.getElementType());
66 append(at.getDimensionSizes());
68 }
else if (llvm::isa<StructType>(type)) {
69 StructType st = llvm::cast<StructType>(type);
71 appendSymRef(st.getNameRef());
72 if (ArrayAttr params = st.getParams()) {
74 append(params.getValue());
81 ret.find(PLACEHOLDER, position) == std::string::npos &&
82 "formatting a Type should not produce the 'PLACEHOLDER' char"
94 size_t position = ret.size();
96 if (llvm::isa<IntegerAttr>(a)) {
97 IntegerAttr ia = llvm::cast<IntegerAttr>(a);
98 Type ty = ia.getType();
99 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
100 ia.getValue().print(ss, !isUnsigned);
101 }
else if (llvm::isa<SymbolRefAttr>(a)) {
102 appendSymRef(llvm::cast<SymbolRefAttr>(a));
103 }
else if (llvm::isa<TypeAttr>(a)) {
104 append(llvm::cast<TypeAttr>(a).getValue());
105 }
else if (llvm::isa<AffineMapAttr>(a)) {
108 filtered_raw_ostream fs(ss, [](
char c) {
return c ==
' '; });
109 llvm::cast<AffineMapAttr>(a).getValue().print(fs);
112 }
else if (llvm::isa<ArrayAttr>(a)) {
113 append(llvm::cast<ArrayAttr>(a).getValue());
119 ret.find(PLACEHOLDER, position) == std::string::npos &&
120 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
126 llvm::interleave(attrs, ss, [
this](Attribute a) { append(a); },
"_");
131 BuildShortTypeString bldr;
133 bldr.ret.reserve(base.size() + attrs.size());
136 auto END = attrs.end();
137 auto IT = attrs.begin();
140 for (
size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
142 bldr.ret.append(base, start, pos - start);
144 assert(IT != END &&
"must have an Attribute for every 'PLACEHOLDER' char");
148 bldr.ret.append(base, start, base.size() - start);
154 bldr.append(ArrayRef(IT, END));
162template <
typename... Types>
class TypeList {
165 template <
typename StreamType>
struct Appender {
168 template <
typename Ty>
static inline void append(StreamType &stream) {
169 stream <<
'\'' << Ty::name <<
'\'';
173 template <
typename First,
typename Second,
typename... Rest>
174 static void append(StreamType &stream) {
175 append<First>(stream);
177 append<Second, Rest...>(stream);
181 static inline void append(StreamType &stream) {
183 append<Types...>(stream);
190 template <
typename T>
static inline bool matches(
const T &value) {
191 return llvm::isa_and_present<Types...>(value);
194 static void reportInvalid(
EmitErrorFn emitError,
const Twine &foundName,
const char *aspect) {
195 InFlightDiagnostic diag = emitError().append(aspect,
" must be one of ");
196 Appender<InFlightDiagnostic>::append(diag);
197 diag.append(
" but found '", foundName,
"'").report();
200 static inline void reportInvalid(
EmitErrorFn emitError, Attribute found,
const char *aspect) {
202 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() :
"nullptr", aspect);
207 static inline std::string
getNames() {
214template <
class... Ts>
struct make_unique {
215 using type = TypeList<Ts...>;
218template <
class... Ts>
struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
220template <
class U,
class... Us,
class... Ts>
221struct make_unique<TypeList<U, Us...>, Ts...>
222 : std::conditional_t<
223 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
224 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
226template <
class... Ts>
using TypeListUnion =
typename make_unique<Ts...>::type;
/// Attribute kinds accepted as an array dimension size; checked with
/// `ArrayDimensionTypes::matches` in `areValidArrayDimSizes`. The element order is
/// the order in which the kinds are listed in the "must be one of" diagnostic.
232using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
/// Attribute kinds accepted as a struct type parameter; checked with
/// `StructParamTypes::matches` in `areValidStructTypeParams`. The element order is
/// the order in which the kinds are listed in the "must be one of" diagnostic.
239using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
242 struct ColumnCheckData {
243 SymbolTableCollection *symbolTable =
nullptr;
244 Operation *op =
nullptr;
247 bool no_felt : 1 =
false;
248 bool no_string : 1 =
false;
249 bool no_non_signal_struct : 1 =
false;
250 bool no_signal_struct : 1 =
false;
251 bool no_array : 1 =
false;
252 bool no_var : 1 =
false;
253 bool no_int : 1 =
false;
254 bool no_struct_params : 1 =
false;
255 bool must_be_column : 1 =
false;
257 ColumnCheckData columnCheck;
262 bool validColumns(StructType s) {
263 if (!must_be_column) {
266 assert(columnCheck.symbolTable);
267 assert(columnCheck.op);
268 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
272 constexpr AllowedTypes &noFelt() {
277 constexpr AllowedTypes &noString() {
282 constexpr AllowedTypes &noStruct() {
283 no_non_signal_struct =
true;
284 no_signal_struct =
true;
288 constexpr AllowedTypes &noStructExceptSignal() {
289 no_non_signal_struct =
true;
290 no_signal_struct =
false;
294 constexpr AllowedTypes &noArray() {
299 constexpr AllowedTypes &noVar() {
304 constexpr AllowedTypes &noInt() {
309 constexpr AllowedTypes &noStructParams(
bool noStructParams =
true) {
310 no_struct_params = noStructParams;
314 constexpr AllowedTypes &onlyInt() {
316 return noFelt().noString().noStruct().noArray().noVar();
319 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
320 must_be_column =
true;
321 columnCheck.symbolTable = &symbolTable;
327 bool isValidTypeImpl(Type type);
329 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes,
EmitErrorFn emitError =
nullptr) {
331 if (dimensionSizes.empty()) {
333 emitError().append(
"array must have at least one dimension").report();
340 for (Attribute a : dimensionSizes) {
341 if (!ArrayDimensionTypes::matches(a)) {
342 ArrayDimensionTypes::reportInvalid(emitError, a,
"Array dimension");
344 }
else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
345 TypeList<IntegerAttr>::reportInvalid(emitError, a,
"Concrete array dimension");
356 bool isValidArrayElemTypeImpl(Type type) {
358 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
361 bool isValidArrayTypeImpl(
362 Type elementType, ArrayRef<Attribute> dimensionSizes,
EmitErrorFn emitError =
nullptr
364 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
369 if (!isValidArrayElemTypeImpl(elementType)) {
377 elementType.getAbstractType().getName(),
"'"
387 bool isValidArrayTypeImpl(Type type) {
388 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
389 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
396 bool areValidStructTypeParams(ArrayAttr params,
EmitErrorFn emitError =
nullptr) {
400 if (no_struct_params) {
404 for (Attribute p : params) {
405 if (!StructParamTypes::matches(p)) {
406 StructParamTypes::reportInvalid(emitError, p,
"Struct parameter");
408 }
else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
409 if (!isValidTypeImpl(tyAttr.getValue())) {
411 emitError().append(
"expected a valid LLZK type but found ", tyAttr.getValue()).report();
415 }
else if (no_var && !llvm::isa<IntegerAttr>(p)) {
416 TypeList<IntegerAttr>::reportInvalid(emitError, p,
"Concrete struct parameter");
429 bool isValidStructTypeImpl(Type type,
bool allowSignalStruct,
bool allowNonSignalStruct) {
430 if (!allowSignalStruct && !allowNonSignalStruct) {
433 if (StructType sType = llvm::dyn_cast<StructType>(type); sType && validColumns(sType)) {
435 (allowNonSignalStruct && areValidStructTypeParams(sType.getParams()));
441bool AllowedTypes::isValidTypeImpl(Type type) {
443 !(no_int && no_felt && no_string && no_var && no_non_signal_struct && no_signal_struct &&
445 "All types have been deactivated"
447 return (!no_int && type.isSignlessInteger(1)) || (!no_int && llvm::isa<IndexType>(type)) ||
448 (!no_felt && llvm::isa<FeltType>(type)) || (!no_string && llvm::isa<StringType>(type)) ||
449 (!no_var && llvm::isa<TypeVarType>(type)) || (!no_array && isValidArrayTypeImpl(type)) ||
450 isValidStructTypeImpl(type, !no_signal_struct, !no_non_signal_struct);
455bool isValidType(Type type) {
return AllowedTypes().isValidTypeImpl(type); }
458 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
464 return AllowedTypes().noString().noStructExceptSignal().isValidTypeImpl(type);
469 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
477 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
481 if (
auto structParamTy = llvm::dyn_cast<StructType>(type)) {
495 bool encountered =
false;
496 type.walk([&](AffineMapAttr a) {
498 return WalkResult::interrupt();
/// Records, for each (AffineMapAttr, Side) pair seen during unification, the
/// IntegerAttr it was unified with. The generic `track` helper overwrites an entry
/// with a null value when the same key is later seen with a different integer, so
/// a null value marks conflicting instantiations.
515using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
518 ArrayRef<StringRef> rhsRevPrefix;
519 UnificationMap *unifications;
520 AffineInstantiations *affineToIntTracker;
523 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
525 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
526 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
527 overrideSuccess(nullptr) {}
530 const ArrayRef<Attribute> &lhsParams,
const ArrayRef<Attribute> &rhsParams,
531 bool unifyDynamicSize =
false
533 auto pred = [
this, unifyDynamicSize](
auto lhsAttr,
auto rhsAttr) {
534 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
536 return (lhsParams.size() == rhsParams.size()) &&
537 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
540 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
541 this->affineToIntTracker = tracker;
545 UnifierImpl &withOverrides(llvm::function_ref<
bool(Type oldTy, Type newTy)> overrides) {
546 this->overrideSuccess = overrides;
553 const ArrayAttr &lhsParams,
const ArrayAttr &rhsParams,
bool unifyDynamicSize =
false
555 if (lhsParams && rhsParams) {
556 return typeParamsUnify(lhsParams.getValue(), rhsParams.getValue(), unifyDynamicSize);
559 return !lhsParams && !rhsParams;
564 if (!
typesUnify(lhs.getElementType(), rhs.getElementType())) {
569 lhs.getDimensionSizes(), rhs.getDimensionSizes(),
true
575 SmallVector<StringRef> rhsNames =
getNames(rhs.getNameRef());
576 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
577 if (rhsNames !=
getNames(lhs.getNameRef())) {
588 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
592 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
593 track(Side::LHS, lhsTvar.getNameRef(), rhs);
596 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
597 track(Side::RHS, rhsTvar.getNameRef(), lhs);
600 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
601 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
603 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
604 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
610 template <
typename Tracker,
typename Key,
typename Val>
611 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
612 auto key = std::make_pair(keyHead, side);
613 auto it = tracker.find(key);
614 if (it == tracker.end()) {
615 tracker.try_emplace(key, val);
616 }
else if (it->getSecond() != val) {
617 it->second =
nullptr;
621 void track(Side side, SymbolRefAttr symRef, Type ty) {
624 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
626 attr = tvar.getNameRef();
629 attr = TypeAttr::get(ty);
633 track(*unifications, side, symRef, attr);
637 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
640 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
641 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
642 attr = tvar.getNameRef();
651 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
652 track(*unifications,
reverse(side), otherSymAttr, symRef);
654 track(*unifications, side, symRef, attr);
658 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
659 if (affineToIntTracker) {
663 track(*affineToIntTracker, side, affineAttr, intAttr);
667 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr,
bool unifyDynamicSize =
false) {
671 if (lhsAttr == rhsAttr) {
676 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
677 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
679 track(Side::LHS, lhsAffine, rhsInt);
684 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
685 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
687 track(Side::RHS, rhsAffine, lhsInt);
694 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
695 track(Side::LHS, lhsSymRef, rhsAttr);
698 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
699 track(Side::RHS, rhsSymRef, lhsAttr);
703 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
704 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
711 auto isa_const = [](Attribute attr) {
712 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
714 if (
auto lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
715 if (isa_const(rhsAttr)) {
719 if (
auto rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
720 if (isa_const(lhsAttr)) {
725 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
726 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
727 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
738 const ArrayRef<Attribute> &lhsParams,
const ArrayRef<Attribute> &rhsParams,
741 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
747 const ArrayAttr &lhsParams,
const ArrayAttr &rhsParams,
UnificationMap *unifications
749 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
755 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
762 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
766 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix,
UnificationMap *unifications
768 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
772 Type oldTy, Type newTy, llvm::function_ref<
bool(Type oldTy, Type newTy)> knownOldToNew
775 AffineInstantiations affineInstantiations;
777 if (!UnifierImpl(&unifications)
778 .trackAffineToInt(&affineInstantiations)
779 .withOverrides(knownOldToNew)
789 auto entryIsRHS = [](
const auto &entry) {
return entry.first.second ==
Side::RHS; };
790 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
794 if (AllowedTypes().onlyInt().isValidTypeImpl(attr.getType())) {
797 return IntegerAttr::get(IndexType::get(attr.getContext()), attr.getValue());
801 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
812 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
813 Type attrTy = intAttr.getType();
814 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
817 .append(
"IntegerAttr must have type 'index' or 'i1' but found '", attrTy,
"'")
827 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
828 AffineMap map = affineAttr.getValue();
829 if (map.getNumResults() != 1) {
833 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
845 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
849 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
854 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
859 using TypeVarAttrs = TypeList<SymbolRefAttr>;
860 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
861 llvm::report_fatal_error(
862 "Legal type parameters are inconsistent. Encountered " +
863 attr.getAbstractAttribute().getName()
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
static constexpr ::llvm::StringLiteral name
::mlir::SymbolRefAttr getNameRef() const
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
llvm::function_ref< mlir::InFlightDiagnostic()> EmitErrorFn
bool isValidGlobalType(Type type)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
SmallVector< Attribute > forceIntAttrTypes(ArrayRef< Attribute > attrList)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
bool isSignalType(Type type)
constexpr char COMPONENT_NAME_SIGNAL[]
Symbol name for the struct/component representing a signal.
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
Attribute forceIntAttrType(Attribute attr)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
IntegerAttr forceIntType(IntegerAttr attr)
bool hasAffineMapAttr(Type type)
int64_t fromAPInt(llvm::APInt i)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)