diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index e91496f6a5..b2f1dbfa5c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -79,15 +79,12 @@ __global__ void [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, - [[maybe_unused]] const index_t groups_count) + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); @@ -141,15 +138,12 @@ __global__ void [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, - [[maybe_unused]] const index_t groups_count) + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); @@ -766,7 +760,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); - gdy *= arg.num_group_ * num_workgroups_per_Conv_N; + gdy = arg.num_group_; + gdz = num_workgroups_per_Conv_N; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); @@ -820,8 +815,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_, - arg.num_group_); + arg.compute_ptr_offset_of_n_); } else { @@ -836,8 +830,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_, - arg.num_group_); + arg.compute_ptr_offset_of_n_); } }; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp index 088fed89ff..d017a40bce 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -83,6 +83,9 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) // When image is larger than 2GB this->conv_params.push_back( {2, 2, 2, 128, 128, {3, 3}, {4096, 2048}, {300, 300}, {3, 3}, {1, 1}, {1, 1}}); + // Split N and G > 1 + this->conv_params.push_back( + {2, 4, 112, 8, 8, {3, 3}, {469, 724}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); this->template Run<2>(); }