mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add new instances and support for small cases in DPP8 GEMM (#896)
This commit is contained in:
committed by
GitHub
parent
85e2e1e2e2
commit
547dbcfbc2
@@ -11,8 +11,14 @@ namespace ck {
|
||||
|
||||
enum struct DppInstr
|
||||
{
|
||||
dpp8_f16_16x16x2 = 0,
|
||||
dpp8_f16_1x32x2 = 0,
|
||||
dpp8_f16_2x16x2,
|
||||
dpp8_f16_2x32x2,
|
||||
dpp8_f16_4x16x2,
|
||||
dpp8_f16_4x32x2,
|
||||
dpp8_f16_8x16x2,
|
||||
dpp8_f16_8x32x2,
|
||||
dpp8_f16_16x16x2,
|
||||
dpp8_f16_32x8x2
|
||||
};
|
||||
|
||||
@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
|
||||
{
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t lanegroup_size = 8;
|
||||
static constexpr index_t m_per_wave = 8;
|
||||
static constexpr index_t n_per_wave = 16;
|
||||
static constexpr index_t m_per_lanegroup = 4;
|
||||
static constexpr index_t n_per_lanegroup = 8;
|
||||
static constexpr index_t m_per_thread = 4;
|
||||
static constexpr index_t n_per_thread = 1;
|
||||
static constexpr index_t k_per_dpp = 2;
|
||||
static constexpr bool share_a = true;
|
||||
using BaseType = half_t;
|
||||
|
||||
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
|
||||
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
|
||||
{
|
||||
dpp8::DppLanegroupGemm<m_per_thread,
|
||||
n_per_thread,
|
||||
k_per_dpp,
|
||||
BaseType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
share_a>{}
|
||||
.Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
|
||||
{
|
||||
@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
|
||||
{
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t lanegroup_size = 8;
|
||||
static constexpr index_t m_per_wave = 4;
|
||||
static constexpr index_t n_per_wave = 32;
|
||||
static constexpr index_t m_per_lanegroup = 4;
|
||||
static constexpr index_t n_per_lanegroup = 8;
|
||||
static constexpr index_t m_per_thread = 4;
|
||||
static constexpr index_t n_per_thread = 1;
|
||||
static constexpr index_t k_per_dpp = 2;
|
||||
static constexpr bool share_a = true;
|
||||
using BaseType = half_t;
|
||||
|
||||
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
|
||||
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
|
||||
{
|
||||
dpp8::DppLanegroupGemm<m_per_thread,
|
||||
n_per_thread,
|
||||
k_per_dpp,
|
||||
BaseType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
share_a>{}
|
||||
.Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
|
||||
{
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t lanegroup_size = 8;
|
||||
static constexpr index_t m_per_wave = 4;
|
||||
static constexpr index_t n_per_wave = 16;
|
||||
static constexpr index_t m_per_lanegroup = 2;
|
||||
static constexpr index_t n_per_lanegroup = 8;
|
||||
static constexpr index_t m_per_thread = 2;
|
||||
static constexpr index_t n_per_thread = 1;
|
||||
static constexpr index_t k_per_dpp = 2;
|
||||
static constexpr bool share_a = true;
|
||||
using BaseType = half_t;
|
||||
|
||||
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
|
||||
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
|
||||
{
|
||||
dpp8::DppLanegroupGemm<m_per_thread,
|
||||
n_per_thread,
|
||||
k_per_dpp,
|
||||
BaseType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
share_a>{}
|
||||
.Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
|
||||
{
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t lanegroup_size = 8;
|
||||
static constexpr index_t m_per_wave = 1;
|
||||
static constexpr index_t n_per_wave = 32;
|
||||
static constexpr index_t m_per_lanegroup = 1;
|
||||
static constexpr index_t n_per_lanegroup = 8;
|
||||
static constexpr index_t m_per_thread = 1;
|
||||
static constexpr index_t n_per_thread = 1;
|
||||
static constexpr index_t k_per_dpp = 2;
|
||||
static constexpr bool share_a = true;
|
||||
using BaseType = half_t;
|
||||
|
||||
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
|
||||
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
|
||||
{
|
||||
dpp8::DppLanegroupGemm<m_per_thread,
|
||||
n_per_thread,
|
||||
k_per_dpp,
|
||||
BaseType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
share_a>{}
|
||||
.Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
|
||||
{
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t lanegroup_size = 8;
|
||||
static constexpr index_t m_per_wave = 2;
|
||||
static constexpr index_t n_per_wave = 32;
|
||||
static constexpr index_t m_per_lanegroup = 2;
|
||||
static constexpr index_t n_per_lanegroup = 8;
|
||||
static constexpr index_t m_per_thread = 2;
|
||||
static constexpr index_t n_per_thread = 1;
|
||||
static constexpr index_t k_per_dpp = 2;
|
||||
static constexpr bool share_a = true;
|
||||
using BaseType = half_t;
|
||||
|
||||
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
|
||||
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
|
||||
{
|
||||
dpp8::DppLanegroupGemm<m_per_thread,
|
||||
n_per_thread,
|
||||
k_per_dpp,
|
||||
BaseType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
share_a>{}
|
||||
.Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
|
||||
{
|
||||
static constexpr index_t wave_size = 32;
|
||||
static constexpr index_t lanegroup_size = 8;
|
||||
static constexpr index_t m_per_wave = 2;
|
||||
static constexpr index_t n_per_wave = 16;
|
||||
static constexpr index_t m_per_lanegroup = 1;
|
||||
static constexpr index_t n_per_lanegroup = 8;
|
||||
static constexpr index_t m_per_thread = 1;
|
||||
static constexpr index_t n_per_thread = 1;
|
||||
static constexpr index_t k_per_dpp = 2;
|
||||
static constexpr bool share_a = true;
|
||||
using BaseType = half_t;
|
||||
|
||||
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
|
||||
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
|
||||
{
|
||||
dpp8::DppLanegroupGemm<m_per_thread,
|
||||
n_per_thread,
|
||||
k_per_dpp,
|
||||
BaseType,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
share_a>{}
|
||||
.Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
|
||||
struct DppSelector
|
||||
{
|
||||
@@ -143,6 +329,12 @@ struct DppSelector
|
||||
return DppInstr::dpp8_f16_8x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 8, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_8x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 16, 16>()
|
||||
{
|
||||
@@ -155,6 +347,36 @@ struct DppSelector
|
||||
return DppInstr::dpp8_f16_32x8x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 1, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_1x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 2, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_2x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 2, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_2x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 4, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_4x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 4, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_4x32x2;
|
||||
}
|
||||
|
||||
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
|
||||
|
||||
__host__ __device__ constexpr DppSelector()
|
||||
@@ -191,7 +413,6 @@ struct DppSelector
|
||||
// in the future when the implementation is more generalized.
|
||||
static_assert(selected_dpp.share_a);
|
||||
static_assert(selected_dpp.n_per_thread == 1);
|
||||
static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
|
||||
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
|
||||
static_assert(selected_dpp.n_per_lanegroup ==
|
||||
selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
|
||||
@@ -215,11 +436,6 @@ struct DppGemm
|
||||
|
||||
__host__ __device__ constexpr DppGemm()
|
||||
{
|
||||
static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
|
||||
"MPerDpp must be either 8, 16 or 32.");
|
||||
static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
|
||||
"NPerDpp must be either 8, 16 or 32.");
|
||||
|
||||
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user