mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Fixes scalar_type definition for llvm builtin mma type
This commit is contained in:
@@ -34,6 +34,17 @@ using f4_t = unsigned _BitInt(4);
|
||||
using f6_t = _BitInt(6); // e2m3 format
|
||||
using bf6_t = unsigned _BitInt(6); // e3m2 format
|
||||
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
|
||||
// native types: bool
|
||||
template <typename T>
|
||||
inline constexpr bool is_native_type()
|
||||
{
|
||||
return is_same_v<T, double> || is_same_v<T, float> || is_same_v<T, half_t> ||
|
||||
is_same_v<T, bhalf_t> || is_same_v<T, int32_t> || is_same_v<T, uint32_t> ||
|
||||
is_same_v<T, int8_t> || is_same_v<T, uint8_t> || is_same_v<T, _BitInt(8)> ||
|
||||
is_same_v<T, unsigned _BitInt(8)> || is_same_v<T, bool>;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Wrapper for native vector type
|
||||
* @tparam T The element type of the vector
|
||||
@@ -43,11 +54,28 @@ template <typename T, index_t Rank>
|
||||
using NativeVectorT = T __attribute__((ext_vector_type(Rank)));
|
||||
|
||||
/**
|
||||
* @brief Mapping of incoming type to local native storage type and vector size
|
||||
* @brief Mapping of incoming type to local native vector storage type and vector size
|
||||
* @tparam T Incoming data type
|
||||
*/
|
||||
template <typename T>
|
||||
struct scalar_type;
|
||||
struct scalar_type
|
||||
{
|
||||
// Basic data type mapping to unsigned _BitInt of appropriate size
|
||||
using type = unsigned _BitInt(8 * sizeof(T));
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief scalar_type trait override for NativeVectorT
|
||||
* @tparam T The vector type
|
||||
* @tparam Rank The number of elements in the vector
|
||||
*/
|
||||
template <typename T, index_t Rank>
|
||||
struct scalar_type<NativeVectorT<T, Rank>>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = Rank;
|
||||
};
|
||||
|
||||
struct f4x2_pk_t
|
||||
{
|
||||
@@ -85,6 +113,39 @@ struct f4x2_pk_t
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written
|
||||
// in the following way:
|
||||
// template<typename T, index_t Rank>
|
||||
// struct scalar_type<T __attribute__((__vector_size__(sizeof(T) * Rank)))>
|
||||
// {
|
||||
// using type = T;
|
||||
// static constexpr index_t vector_size = Rank;
|
||||
// };
|
||||
// The compiler errors out with "partial specialization is not allowed for this type",
|
||||
// claiming that the Rank is not a deducible parameter. This might be a compiler bug.
|
||||
// Note the above type is classified differently from the NativeVectorT<T, Rank> alias,
|
||||
// even though they are functionally equivalent and are trivially constructibe from each other.
|
||||
// This is unfortunate, but we have to work around it because some LLVM builtins for some
|
||||
// operations (e.g., mma) may return the former type.
|
||||
// For now we have to explicitly specialize for each vector size we need. These are used
|
||||
// in f6_pk_t below.
|
||||
|
||||
/// @brief scalar_type trait override for uint32_t vector of size 3
|
||||
template <>
|
||||
struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 3)))>
|
||||
{
|
||||
using type = uint32_t;
|
||||
static constexpr index_t vector_size = 3;
|
||||
};
|
||||
|
||||
/// @brief scalar_type trait override for uint32_t vector of size 6
|
||||
template <>
|
||||
struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 6)))>
|
||||
{
|
||||
using type = uint32_t;
|
||||
static constexpr index_t vector_size = 6;
|
||||
};
|
||||
|
||||
template <typename BitType, index_t pk_size>
|
||||
struct f6_pk_t
|
||||
{
|
||||
@@ -105,16 +166,36 @@ struct f6_pk_t
|
||||
|
||||
using type = f6_pk_t<BitType, packed_size>;
|
||||
|
||||
/** This class may trivially constructed by the following vector type alias
|
||||
* for example from a result of an mma operation. This is primarily for internal use.
|
||||
* @note f6x16_pk_t and f6x32_pk_t storage types, may be trivially constructed from
|
||||
* uint32_t vectors of size 3 and 6 respectively for example from mma operation results.
|
||||
* Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a
|
||||
* NativeVectorT<uint32_t, 6> is NOT the same as __attribute__((__vector_size__(6 *
|
||||
* sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being
|
||||
* functionally equivalent. This class may be trivially constructed from both, so we can steer
|
||||
* the templated ctor below to only consider incoming vectors types other than our two storage
|
||||
* types of interest.
|
||||
*/
|
||||
using storage_type_alias =
|
||||
element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size)));
|
||||
|
||||
__host__ __device__ constexpr f6_pk_t() {}
|
||||
__host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init}
|
||||
{
|
||||
// TODO: consider removing initialization similar to vector_type<T, 256>
|
||||
}
|
||||
|
||||
// Initialize from a vector type with the same size as packed_size
|
||||
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
|
||||
// Initialize from a vector type with the same size as packed_size.
|
||||
// Exclude storage_type and storage_type_alias because these are trivially constructible.
|
||||
template <
|
||||
typename T,
|
||||
typename = enable_if_t<!is_same_v<T, storage_type> && !is_same_v<T, storage_type_alias> &&
|
||||
scalar_type<T>::vector_size == packed_size>>
|
||||
__host__ __device__ f6_pk_t(const T& v)
|
||||
{
|
||||
static_assert(scalar_type<T>::vector_size == packed_size,
|
||||
"Input vector size must match packed_size.");
|
||||
static_for<0, packed_size, 1>{}(
|
||||
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
|
||||
}
|
||||
@@ -202,17 +283,6 @@ struct pk_i4_t
|
||||
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
|
||||
// native types: bool
|
||||
template <typename T>
|
||||
inline constexpr bool is_native_type()
|
||||
{
|
||||
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
|
||||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
|
||||
is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value;
|
||||
}
|
||||
|
||||
// is_scalar_type
|
||||
template <typename TV>
|
||||
struct is_scalar_type
|
||||
@@ -369,18 +439,6 @@ struct scalar_type<bf6x16_pk_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief scalar_type trait override for NativeVectorT
|
||||
* @tparam T The vector type
|
||||
* @tparam Rank The number of elements in the vector
|
||||
*/
|
||||
template <typename T, index_t Rank>
|
||||
struct scalar_type<NativeVectorT<T, Rank>>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = Rank;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct packed_type_info
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user