From 4f83a3d745190728545d52af2f031c519d9e0f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 21 Oct 2025 15:41:02 +0200 Subject: [PATCH] Gridwise gemm conv v3 force padded layout on gfx950 (#2961) * Gridwise gemm conv v3 force padded layout on gfx950 * fix bug in other gridwise * fix * Update gridwise_gemm_wmma_cshuffle_v3_common.hpp [ROCm/composable_kernel commit: 3a28632b203f9219ed4906d46457872ef1084054] --- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 20 +++++++++++++++---- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 2 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 20 ++++++++++++++++--- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 3940c42c20..60ad4651b6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -45,7 +45,7 @@ template {}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + make_tuple(Number{} * AK1Number, AK1Number, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. @@ -412,12 +418,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t WaveSize = BlockSize / (MWave * NWave); +#if defined(__gfx950__) + // Force use padded layout on gfx950 to reduce bank conflicts + constexpr index_t BBlockLdsExtraN = 1; +#else + constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom; +#endif // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + make_tuple(Number{} * BK1Number, BK1Number, I1)); } else if constexpr(is_same::value) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index a6e4870ac7..11b75a6541 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -828,7 +828,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 // loop to hide it in v4. it may give you some benefit from less valu in compute address return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); + make_tuple(Number{} * AK1Number, AK1Number, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 5b19ff8542..e2071e061d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -131,7 +131,7 @@ template {}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); + make_tuple(Number{} * AK1Number, AK1Number, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. @@ -840,6 +847,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t WaveSize = BlockSize / (MWave * NWave); +#if defined(__gfx950__) + // Force use padded layout on gfx950 to reduce bank conflicts + constexpr index_t BBlockLdsExtraN = 1; +#else + constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom; +#endif + // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) {