diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index b2e6061d45..ff0bb10d0c 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -34,6 +34,17 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, +// native types: bool +template +inline constexpr bool is_native_type() +{ + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + /** * @brief Wrapper for native vector type * @tparam T The element type of the vector @@ -43,11 +54,28 @@ template using NativeVectorT = T __attribute__((ext_vector_type(Rank))); /** - * @brief Mapping of incoming type to local native storage type and vector size + * @brief Mapping of incoming type to local native vector storage type and vector size * @tparam T Incoming data type */ template -struct scalar_type; +struct scalar_type +{ + // Basic data type mapping to unsigned _BitInt of appropriate size + using type = unsigned _BitInt(8 * sizeof(T)); + static constexpr index_t vector_size = 1; +}; + +/** + * @brief scalar_type trait override for NativeVectorT + * @tparam T The vector type + * @tparam Rank The number of elements in the vector + */ +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = Rank; +}; struct f4x2_pk_t { @@ -85,6 +113,39 @@ struct f4x2_pk_t } }; +// TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written +// in the following way: +// template +// struct scalar_type +// { +// using type = T; +// static constexpr index_t vector_size = Rank; +// }; +// The compiler errors out with "partial specialization is not allowed for this type", +// claiming that the Rank is not a deducible parameter. This might be a compiler bug. +// Note the above type is classified differently from the NativeVectorT alias, +// even though they are functionally equivalent and are trivially constructibe from each other. +// This is unfortunate, but we have to work around it because some LLVM builtins for some +// operations (e.g., mma) may return the former type. +// For now we have to explicitly specialize for each vector size we need. These are used +// in f6_pk_t below. + +/// @brief scalar_type trait override for uint32_t vector of size 3 +template <> +struct scalar_type +{ + using type = uint32_t; + static constexpr index_t vector_size = 3; +}; + +/// @brief scalar_type trait override for uint32_t vector of size 6 +template <> +struct scalar_type +{ + using type = uint32_t; + static constexpr index_t vector_size = 6; +}; + template struct f6_pk_t { @@ -105,16 +166,36 @@ struct f6_pk_t using type = f6_pk_t; + /** This class may trivially constructed by the following vector type alias + * for example from a result of an mma operation. This is primarily for internal use. + * @note f6x16_pk_t and f6x32_pk_t storage types, may be trivially constructed from + * uint32_t vectors of size 3 and 6 respectively for example from mma operation results. + * Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a + * NativeVectorT is NOT the same as __attribute__((__vector_size__(6 * + * sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being + * functionally equivalent. This class may be trivially constructed from both, so we can steer + * the templated ctor below to only consider incoming vectors types other than our two storage + * types of interest. + */ + using storage_type_alias = + element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size))); + __host__ __device__ constexpr f6_pk_t() {} __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init} { // TODO: consider removing initialization similar to vector_type } - // Initialize from a vector type with the same size as packed_size - template ::vector_size == packed_size>> + // Initialize from a vector type with the same size as packed_size. + // Exclude storage_type and storage_type_alias because these are trivially constructible. + template < + typename T, + typename = enable_if_t && !is_same_v && + scalar_type::vector_size == packed_size>> __host__ __device__ f6_pk_t(const T& v) { + static_assert(scalar_type::vector_size == packed_size, + "Input vector size must match packed_size."); static_for<0, packed_size, 1>{}( [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); } @@ -202,17 +283,6 @@ struct pk_i4_t __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} }; -// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool -template -inline constexpr bool is_native_type() -{ - return is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same_v || is_same_v || is_same::value; -} - // is_scalar_type template struct is_scalar_type @@ -369,18 +439,6 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -/** - * @brief scalar_type trait override for NativeVectorT - * @tparam T The vector type - * @tparam Rank The number of elements in the vector - */ -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = Rank; -}; - template struct packed_type_info {