diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f5d9a789fa..bb97d29534 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -86,7 +86,6 @@ __global__ void const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, - const index_t groups_count, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -101,10 +100,8 @@ __global__ void defined(__gfx94__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const long_index_t e_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); @@ -200,7 +197,6 @@ __global__ void ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = groups_count; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -321,8 +317,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + ADataType, + EDataType>; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -730,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N; - const index_t gdz = 1; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -780,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count as_grid_desc_ak0_m_ak1, bs_grid_desc_bk0_n_bk1, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -824,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,