diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 383b872832..3d6f34f121 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1671,7 +1671,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } - else + } + else + { + if constexpr(NXdlPerWave32 > 0) { if(!GridwiseGemmCTranspose32::CheckValidity( arg.a_grid_desc_m_k_container_[i], @@ -1686,10 +1689,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } - if(!valid) - { - return false; - } + } + if(!valid) + { + return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index a97e4503a8..1d9b7eb978 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -561,9 +561,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return; } - const index_t num_k_per_block = + const index_t num_ak0_per_block = __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); - + const index_t num_bk0_per_block = + __builtin_amdgcn_readfirstlane(b_grid_desc_bk0_n_bk1.GetLength(I0) / k_batch); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -605,7 +606,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, - make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0), + make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -636,7 +637,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, - make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0), + make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0),