diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index aa632eb3f6..23dcbb7486 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -779,24 +779,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const bool is_b_stride_divisible = b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0; + // Check if descriptor is contiguous (no padding in layout) + // For a contiguous (K0, M, K1) descriptor, ElementSpaceSize should equal K0*M*K1 + const long_index_t expected_size_a = static_cast( + a_grid_desc_k0_m_k1_.GetLength(I0) * a_grid_desc_k0_m_k1_.GetLength(I1) * + a_grid_desc_k0_m_k1_.GetLength(I2)); + const bool is_a_contiguous = + a_grid_desc_k0_m_k1_.GetElementSpaceSize() == expected_size_a; + + const long_index_t expected_size_b = static_cast( + b_grid_desc_k0_n_k1_.GetLength(I0) * b_grid_desc_k0_n_k1_.GetLength(I1) * + b_grid_desc_k0_n_k1_.GetLength(I2)); + const bool is_b_contiguous = + b_grid_desc_k0_n_k1_.GetElementSpaceSize() == expected_size_b; + // Determine if we can safely use the split-k offset hack + // The hack requires contiguous descriptor layout for correct stride calculation split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch && - is_k_not_paded && is_correct_layout && is_a_stride_divisible; + is_k_not_paded && is_correct_layout && is_a_stride_divisible && + is_a_contiguous; split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded && - is_correct_layout && is_b_stride_divisible; + is_correct_layout && is_b_stride_divisible && is_b_contiguous; - // Calculate stride from descriptor size - // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, - // so we need to divide by k_batch_ to get the per-batch stride when the hack is - // enabled. The divisibility check above ensures no truncation occurs. - split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + // Calculate stride for split-k offset hack + // The descriptor has shape (K0_full, M, K1) where K0_full = k_batch_ * K0_per_batch + // To advance from k_idx=0 to k_idx=1, we need to skip K0_per_batch slices + // Each K0 slice occupies M*K1 elements, so stride = (K0_full/k_batch_) * M * K1 if(split_k_offset_a_hack_) - split_k_stride_a_ /= k_batch_; + { + const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_; + split_k_stride_a_ = k0_per_batch * a_grid_desc_k0_m_k1_.GetLength(I1) * + a_grid_desc_k0_m_k1_.GetLength(I2); + } + else + { + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + } - split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); if(split_k_offset_b_hack_) - split_k_stride_b_ /= k_batch_; + { + const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_; + split_k_stride_b_ = k0_per_batch * b_grid_desc_k0_n_k1_.GetLength(I1) * + b_grid_desc_k0_n_k1_.GetLength(I2); + } + else + { + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + } const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);