Fixes scalar_type definition for llvm builtin mma type

This commit is contained in:
Chris Millette
2026-01-29 13:33:38 -05:00
parent b731dc17d1
commit 13a5177923

View File

@@ -34,6 +34,17 @@ using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format
/**
 * @brief Compile-time predicate: is T one of the native scalar types?
 *
 * The accepted set is exactly the types compared below: double, float, half_t, bhalf_t,
 * int32_t, uint32_t, int8_t, uint8_t, _BitInt(8), unsigned _BitInt(8) and bool.
 * NOTE(review): the previous comment described this list as "double, float, _Float16, ushort,
 * int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, bool"; the alias definitions backing those
 * names are not visible here -- confirm before relying on that mapping.
 *
 * @tparam T Type to classify
 * @return true when T is identical to one of the listed types, false otherwise
 */
template <typename T>
inline constexpr bool is_native_type()
{
    constexpr bool matches_float_or_half_types = is_same_v<T, double> || is_same_v<T, float> ||
                                                 is_same_v<T, half_t> || is_same_v<T, bhalf_t>;
    constexpr bool matches_int32_types = is_same_v<T, int32_t> || is_same_v<T, uint32_t>;
    constexpr bool matches_byte_types  = is_same_v<T, int8_t> || is_same_v<T, uint8_t> ||
                                        is_same_v<T, _BitInt(8)> ||
                                        is_same_v<T, unsigned _BitInt(8)>;
    constexpr bool matches_bool = is_same_v<T, bool>;

    return matches_float_or_half_types || matches_int32_types || matches_byte_types ||
           matches_bool;
}
/**
* @brief Wrapper for native vector type
* @tparam T The element type of the vector
@@ -43,11 +54,28 @@ template <typename T, index_t Rank>
using NativeVectorT = T __attribute__((ext_vector_type(Rank)));
/**
* @brief Mapping of incoming type to local native storage type and vector size
* @brief Mapping of incoming type to local native vector storage type and vector size
* @tparam T Incoming data type
*/
template <typename T>
struct scalar_type;
struct scalar_type
{
// Basic data type mapping to unsigned _BitInt of appropriate size
// (8 * sizeof(T) bits, i.e. 8 bits for every byte of T).
using type = unsigned _BitInt(8 * sizeof(T));
// Any T handled by this primary template is reported as a single element.
static constexpr index_t vector_size = 1;
};
/**
 * @brief scalar_type specialization for NativeVectorT
 *
 * Exposes the element type and the element count of a NativeVectorT<T, Rank>.
 *
 * @tparam T The element type of the vector
 * @tparam Rank The number of elements in the vector
 */
template <typename T, index_t Rank>
struct scalar_type<NativeVectorT<T, Rank>>
{
    // Number of elements in the vector
    static constexpr index_t vector_size = Rank;
    // Element type of the vector
    using type = T;
};
struct f4x2_pk_t
{
@@ -85,6 +113,39 @@ struct f4x2_pk_t
}
};
// TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written
// in the following way:
// template<typename T, index_t Rank>
// struct scalar_type<T __attribute__((__vector_size__(sizeof(T) * Rank)))>
// {
// using type = T;
// static constexpr index_t vector_size = Rank;
// };
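// The compiler errors out with "partial specialization is not allowed for this type",
// claiming that the Rank is not a deducible parameter. This is most likely expected behavior
// rather than a compiler bug: Rank only appears inside the expression sizeof(T) * Rank, and a
// template parameter used as a subexpression of a larger expression is a non-deduced context.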
// Note the above type is classified differently from the NativeVectorT<T, Rank> alias,
// even though they are functionally equivalent and are trivially constructible from each other.
// This is unfortunate, but we have to work around it because some LLVM builtins for some
// operations (e.g., mma) may return the former type.
// For now we have to explicitly specialize for each vector size we need. These are used
// in f6_pk_t below.
/// @brief scalar_type trait override for the __vector_size__-style uint32_t vector holding
///        3 elements (3 * sizeof(uint32_t) bytes)
template <>
struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 3)))>
{
    // Number of uint32_t elements in the vector
    static constexpr index_t vector_size = 3;
    // Element type of the vector
    using type = uint32_t;
};
/// @brief scalar_type trait override for the __vector_size__-style uint32_t vector holding
///        6 elements (6 * sizeof(uint32_t) bytes)
template <>
struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 6)))>
{
    // Number of uint32_t elements in the vector
    static constexpr index_t vector_size = 6;
    // Element type of the vector
    using type = uint32_t;
};
template <typename BitType, index_t pk_size>
struct f6_pk_t
{
@@ -105,16 +166,36 @@ struct f6_pk_t
using type = f6_pk_t<BitType, packed_size>;
/** This class may be trivially constructed from the following vector type alias,
* for example from the result of an mma operation. This is primarily for internal use.
* @note f6x16_pk_t and f6x32_pk_t storage types, may be trivially constructed from
* uint32_t vectors of size 3 and 6 respectively for example from mma operation results.
* Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a
* NativeVectorT<uint32_t, 6> is NOT the same as __attribute__((__vector_size__(6 *
* sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being
* functionally equivalent. This class may be trivially constructed from both, so we can steer
* the templated ctor below to only consider incoming vectors types other than our two storage
* types of interest.
*/
using storage_type_alias =
element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size)));
// Default constructor: the body is empty and no member initializer is given here.
__host__ __device__ constexpr f6_pk_t() {}
// Construct from an already-packed storage value: data_ is initialized directly from init.
__host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init}
{
// TODO: consider removing initialization similar to vector_type<T, 256>
}
// Initialize from a vector type with the same size as packed_size
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
// Element-wise constructor from a vector type T whose scalar_type<T>::vector_size equals
// packed_size. storage_type and storage_type_alias are kept out of overload resolution
// because those are trivially constructible.
template <
    typename T,
    typename = enable_if_t<!(is_same_v<T, storage_type> || is_same_v<T, storage_type_alias>) &&
                           scalar_type<T>::vector_size == packed_size>>
__host__ __device__ f6_pk_t(const T& v)
{
    static_assert(scalar_type<T>::vector_size == packed_size,
                  "Input vector size must match packed_size.");

    // Hand every input element to pack() together with its index
    // (static_for<0, packed_size, 1> presumably visits 0 .. packed_size - 1 -- confirm).
    static_for<0, packed_size, 1>{}([&](auto slot) {
        const index_t idx = static_cast<index_t>(slot);
        pack(v[idx], idx);
    });
}
@@ -202,17 +283,6 @@ struct pk_i4_t
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value ||
is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value ||
is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value;
}
// is_scalar_type
template <typename TV>
struct is_scalar_type
@@ -369,18 +439,6 @@ struct scalar_type<bf6x16_pk_t>
static constexpr index_t vector_size = 1;
};
/**
 * @brief scalar_type specialization for NativeVectorT
 *
 * Exposes the element type and the element count of a NativeVectorT<T, Rank>.
 *
 * @tparam T The element type of the vector
 * @tparam Rank The number of elements in the vector
 */
template <typename T, index_t Rank>
struct scalar_type<NativeVectorT<T, Rank>>
{
    // Number of elements in the vector
    static constexpr index_t vector_size = Rank;
    // Element type of the vector
    using type = T;
};
template <typename T>
struct packed_type_info
{