mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Merging the gfx12 code into public repo. (#1362)
This commit is contained in:
@@ -11,12 +11,17 @@ namespace ck {
|
||||
|
||||
enum struct WmmaInstr
|
||||
{
|
||||
// gfx11
|
||||
wmma_f32_16x16x16_f16 = 0,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
wmma_f16_16x16x16_f16,
|
||||
wmma_bf16_16x16x16_bf16,
|
||||
wmma_i32_16x16x16_iu8,
|
||||
wmma_i32_16x16x16_iu4
|
||||
wmma_i32_16x16x16_iu4,
|
||||
// gfx12
|
||||
wmma_f32_16x16x16_f16_gfx12,
|
||||
wmma_f32_16x16x16_bf16_gfx12,
|
||||
wmma_i32_16x16x16_iu8_gfx12,
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
}
|
||||
};
|
||||
|
||||
// gfx12
|
||||
|
||||
// A-swizzled
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
// * Data Pixel
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
// static constexpr index_t acc_data_size = 4;
|
||||
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
|
||||
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
|
||||
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
|
||||
// * num_acc_vgprs_per_wave alone M direction
|
||||
// * num_subgroups alone M direction
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = false,
|
||||
bool neg_b = false,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
||||
a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_type_a,
|
||||
typename src_type_b,
|
||||
typename dst_type,
|
||||
@@ -296,13 +417,21 @@ struct WmmaSelector
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
#else
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
#else
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -320,8 +449,13 @@ struct WmmaSelector
|
||||
template <>
|
||||
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
#else
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
||||
@@ -502,6 +636,9 @@ struct WmmaGemm
|
||||
|
||||
__device__ static auto GetSubGroupId()
|
||||
{
|
||||
static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
|
||||
wmma_instr.wave_size,
|
||||
"");
|
||||
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
|
||||
}
|
||||
|
||||
@@ -516,12 +653,20 @@ struct WmmaGemm
|
||||
|
||||
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return GetLaneIdUnderSubGroup();
|
||||
#else
|
||||
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return GetLaneIdUnderSubGroup();
|
||||
#else
|
||||
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static CIndex GetBeginOfThreadBlk()
|
||||
|
||||
Reference in New Issue
Block a user