Add grouped gemm instances for RDNA4 (#3237)

* wip: grouped_gemm implementation based on wmma kernel + example for fp16

* chore: clean up grouped_gemm_wmma_splitk_fp16 example

* chore: add cmake options to fully disable XDL or WMMA kernels

* feat: add tests for grouped gemm wmma instances for f16 and bf16 (all layouts)

* chore: add grouped gemm wmma bf16 example

* refactor: reuse more code between instance factory functions

* chore: turn test failure if not all batch sizes are supported into a warning

* chore: made failing of test on unsupported instances conditional to not break old tests

* chore: add log message to failure case where AK1/BK1/KBatch is too high for K value

* fix: issue with new overloads of GridwiseGemm_wmma_cshuffle_v3::Run()

* fix: stray comma after parameter list

* fix: compilation issues on RDNA3 and tests failing due to unsupported problems still being run

* chore: update copyright in header comments

* nit: minor feedback

* refactor: unified XDL / wmma tests

* fix: properly disable FP8 instances when ONLY targeting gfx11

* refactor: add v3 suffix to grouped_gemm device struct name

* fix: small typos in example code

* fix: fully exclude xdl/wmma instances when using the corresponding cmake flags

* chore: remove unused destructor and added pipeline support checks to remove unnecessary paths

* fix: make sure to not add instance library to group if library was skipped

* fix: make sure xdl grouped gemm doesn't fail the new test

* fix: explicitly exclude test if no xdl/wmma support, as pattern matching fails in this case

* fix: examples not working since dependent types and functions were moved to ck namespace in develop

* fix: tests failing when compiling for just gfx11 due to trying to run unsupported instances

* chore: replace/add copyright headers with new format
This commit is contained in:
Erwin Terpstra
2025-12-02 00:32:10 +01:00
committed by GitHub
parent 23fb253c4e
commit 46f1d740f0
30 changed files with 2291 additions and 268 deletions

View File

@@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
@@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem,
const Block2CTileMap& block_2_ctile_map,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
@@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
const index_t block_m_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
const index_t block_n_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
@@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
epilogue_args);
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
{
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
problem,
DefaultBlock2CTileMap(problem),
a_element_op,
b_element_op,
cde_element_op,
epilogue_args);
}
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
const Block2CTileMap& block_2_ctile_map,
EpilogueArgument& epilogue_args)
{
// shift A matrices pointer for splitk
@@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument,
BlockMapMBlockIndex,
BlockMapNBlockIndex>(p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,
block_2_ctile_map,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
}
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
{
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument>(
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args);
}
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
{
return Block2CTileMap{problem.M, problem.N, 4};
}
};

View File

@@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. AK1: "
<< AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch
<< ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}