MX GEMM examples with FP8, FP16, and E8M0 scales (#2016)

* Add `scalar_type` specification for E8M0 exponent

* Specialize `nnvb_data_t_selector` for E8M0 exponent

* Remove partial specializations for `scalar_type` of `non_native_vector_base` template

* Reword command line helper string

* Create MX GEMM examples for different scales


[ROCm/composable_kernel commit: 72d888821c]
This commit is contained in:
Andriy Roshchenko
2025-03-25 15:33:03 -06:00
committed by GitHub
parent 21af4139ad
commit 75ef4c83bf
8 changed files with 140 additions and 41 deletions

View File

@@ -463,6 +463,13 @@ struct scalar_type<bf8_ocp_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bool>
{

View File

@@ -1212,6 +1212,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
};
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
@@ -1400,29 +1406,9 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
struct scalar_type<non_native_vector_base<T, N>>
{
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
using type = typename non_native_vector_base<T, N>::data_t;
static constexpr index_t vector_size = N;
};