diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 46140ac0c2..a47a2f90cc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -732,7 +732,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + false, // Don't modify KBatch dimension + false, // Don't modify KBatch dimension + true); // use_full_batch_kindex: keep full KBatch*K0 dimension a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 1286681a5b..cbca852e03 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -693,15 +693,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0; - const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1; - const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1; + // When hack is enabled, use GetElementSpaceSize() divided by k_batch for buffer size. + // This matches the stride calculation in the device layer and correctly accounts for + // the memory layout encoded in GetElementSpaceSize(). + const long_index_t a_buffer_size = + split_k_offset_a_hack ? (a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / k_batch) + : a_b_k0_m_k1_grid_desc.GetElementSpaceSize(); + + const long_index_t b_buffer_size = + split_k_offset_b_hack ? (b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / k_batch) + : b_b_k0_n_k1_grid_desc.GetElementSpaceSize(); const auto a_grid_buf = make_dynamic_buffer( - p_a_grid + split_k_offset_a, - a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / a_space_size_divisor); + p_a_grid + split_k_offset_a, a_buffer_size); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + split_k_offset_b, - b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / b_space_size_divisor); + p_b_grid + split_k_offset_b, b_buffer_size); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());