Add gemm universal f8 f8 bf16 instances on gfx950 (#2662)

This commit is contained in:
jefyang1
2025-08-14 13:25:24 -07:00
committed by GitHub
parent 10395fc895
commit d7c95dd491
8 changed files with 174 additions and 22 deletions

View File

@@ -36,16 +36,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
enum struct Arch : bool
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 on pre-gfx950
if constexpr(((GridwiseGemm::AK1Number >= 32 || GridwiseGemm::BK1Number >= 32) &&
static_cast<bool>(Arch::is_gfx950_build)) ||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32))
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
}
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -64,20 +78,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunks at the same time without order dependency
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
enum struct Arch : bool
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 on pre-gfx950
if constexpr(((GridwiseGemm::AK1Number >= 32 || GridwiseGemm::BK1Number >= 32) &&
static_cast<bool>(Arch::is_gfx950_build)) ||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32))
{
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunks at the same time without order dependency
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared_0,
p_shared_1,
karg);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared_0,
p_shared_1,
karg);
}
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))