[ck] add gridwise base class for in all xdl kernel (#186) (#3544)

1. Add base class GridwiseGemm_xdl_cshuffle_base for all gridwise_gemm_xdl classes.
- to select correct LDS layout and epilogue behavior , three additional parameters is added.
- ForceNaiveLdsLayout: disable XOR based LDS layout when it is true
- DirectLoad: pipeline only use directload, we need force naive layout and ignore any padding on gfx9
- IsMxGemm: epilogue has two addtional dimensions
2. Move all LDS descriptor layout related fucntion to base class, including
- GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
- GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
- GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
3. Move several LDS related helper funtions to base class, including
- GetSharedMemoryNumberOfByte
- GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1
- GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1
- GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
4. Move all c epilogue related code to base class, and 4 kind of implementation are provided
- RunEpilogueNoShuffle
- RunEpilogue
- RunMultiDEpilogue
- RunMoeEpilogue
This commit is contained in:
linqunAMD
2026-01-28 04:49:47 +08:00
committed by GitHub
parent b737f1dee5
commit 23cefda140
85 changed files with 9016 additions and 20081 deletions

View File

@@ -293,7 +293,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
// convolution forward. For some reason for that specific type there is an ambiguity
// in the type resolution for the ternary expression. I added an explicit cast to
// disambiguate and only use it for f8 just in case it affects performance.
if constexpr(std::is_same_v<scalar_t, ck::f8_ocp_t>)
if constexpr(is_same_v<scalar_t, ck::f8_ocp_t>)
{
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vector_t{elm_vectors(i).template AsType<elm_vector_t>()[I0]}