mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
MX GEMM examples with FP8, FP16, and E8M0 scales (#2016)
* Add `scalar_type` specification for E8M0 exponent
* Specialize `nnvb_data_t_selector` for E8M0 exponent
* Remove partial specializations for `scalar_type` of `non_native_vector_base` template
* Reword command line helper string
* Create MX GEMM examples for different scales
[ROCm/composable_kernel commit: 72d888821c]
This commit is contained in:
committed by
GitHub
parent
21af4139ad
commit
75ef4c83bf
@@ -463,6 +463,13 @@ struct scalar_type<bf8_ocp_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<e8m0_bexp_t>
|
||||
{
|
||||
using type = e8m0_bexp_t::type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bool>
|
||||
{
|
||||
|
||||
@@ -1212,6 +1212,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
|
||||
using type = bf8_ocp_t::data_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<e8m0_bexp_t>
|
||||
{
|
||||
using type = e8m0_bexp_t::type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f6x16_pk_t>
|
||||
{
|
||||
@@ -1400,29 +1406,9 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<non_native_vector_base<T, N>>;
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
|
||||
struct scalar_type<non_native_vector_base<T, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
|
||||
|
||||
using type = typename non_native_vector_base<T, N>::data_t;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user