From dde91b60fbd806a0aea6e9f4a30ac36a5ec6aa9c Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Mon, 29 Sep 2025 22:56:33 +0800 Subject: [PATCH] [CK] Fix example_grouped_conv_bwd_data_xdl_fp16 with ksplit = 2 (#2943) Root cause: AK1 and BK1 may differ in the class template, so we need to calculate k0 per block separately for A and B when ksplit is not 1. [ROCm/composable_kernel commit: 769c58f13399403bbe22350eaddceb4a5fd38b3d] --- ...ped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 13 ++++++++----- .../grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 9 +++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 383b872832..3d6f34f121 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1671,7 +1671,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } - else + } + else + { + if constexpr(NXdlPerWave32 > 0) { if(!GridwiseGemmCTranspose32::CheckValidity( arg.a_grid_desc_m_k_container_[i], @@ -1686,10 +1689,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } - if(!valid) - { - return false; - } + } + if(!valid) + { + return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index a97e4503a8..1d9b7eb978 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -561,9 +561,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return; } - const index_t num_k_per_block = + const index_t 
num_ak0_per_block = __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); - + const index_t num_bk0_per_block = + __builtin_amdgcn_readfirstlane(b_grid_desc_bk0_n_bk1.GetLength(I0) / k_batch); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -605,7 +606,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, - make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0), + make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -636,7 +637,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, - make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0), + make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0),