Add new instances and support for small cases in DPP8 GEMM (#896)

This commit is contained in:
Bartlomiej Wroblewski
2023-09-12 17:05:23 +02:00
committed by GitHub
parent 85e2e1e2e2
commit 547dbcfbc2
11 changed files with 562 additions and 23 deletions

View File

@@ -11,8 +11,14 @@ namespace ck {
enum struct DppInstr
{
dpp8_f16_16x16x2 = 0,
dpp8_f16_1x32x2 = 0,
dpp8_f16_2x16x2,
dpp8_f16_2x32x2,
dpp8_f16_4x16x2,
dpp8_f16_4x32x2,
dpp8_f16_8x16x2,
dpp8_f16_8x32x2,
dpp8_f16_16x16x2,
dpp8_f16_32x8x2
};
@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 8;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 4;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 4;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 4;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 1;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 32;
static constexpr index_t m_per_lanegroup = 2;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 2;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
static constexpr index_t wave_size = 32;
static constexpr index_t lanegroup_size = 8;
static constexpr index_t m_per_wave = 2;
static constexpr index_t n_per_wave = 16;
static constexpr index_t m_per_lanegroup = 1;
static constexpr index_t n_per_lanegroup = 8;
static constexpr index_t m_per_thread = 1;
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
using BaseType = half_t;
template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
__device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppLanegroupGemm<m_per_thread,
n_per_thread,
k_per_dpp,
BaseType,
ADataType,
BDataType,
CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
@@ -143,6 +329,12 @@ struct DppSelector
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
{
@@ -155,6 +347,36 @@ struct DppSelector
return DppInstr::dpp8_f16_32x8x2;
}
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector()
@@ -191,7 +413,6 @@ struct DppSelector
// in the future when the implementation is more generalized.
static_assert(selected_dpp.share_a);
static_assert(selected_dpp.n_per_thread == 1);
static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
static_assert(selected_dpp.n_per_lanegroup ==
selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
@@ -215,11 +436,6 @@ struct DppGemm
__host__ __device__ constexpr DppGemm()
{
static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
"MPerDpp must be either 8, 16 or 32.");
static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
"NPerDpp must be either 8, 16 or 32.");
static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
}