From 0ecba120e09ddb0c449ff70c341681190a78672a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 30 Dec 2025 16:45:39 +0100 Subject: [PATCH] Fix grouped conv wrw kernels names (#3494) [ROCm/composable_kernel commit: 2b8302eb6d2217c0f537c28538265f4003ec416e] --- ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 12 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 624 +++++++++--------- 2 files changed, 330 insertions(+), 306 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 37fe0b2c7b..ab43430512 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -48,7 +48,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage( typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -468,7 +468,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -916,7 +916,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -931,7 +931,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -957,7 +957,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -972,7 +972,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index e975534a06..97a632664c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -48,7 +48,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage( typename GridwiseGemm::Argument karg, [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -106,7 +106,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds( typename GridwiseGemm::Argument karg, [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -532,7 +532,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -549,7 +549,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -997,7 +997,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -1012,7 +1012,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } else { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -1033,43 +1033,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::One>; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -1080,7 +1045,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, - TailNumber::Two>; + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; Run(kernel); } } @@ -1090,18 +1093,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Three>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; Run(kernel); } } @@ -1111,18 +1115,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Four>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; Run(kernel); } } @@ -1132,18 +1137,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Five>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; Run(kernel); } } @@ -1152,18 +1158,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Six>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; Run(kernel); } } @@ -1173,18 +1180,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Seven>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; Run(kernel); } } @@ -1193,43 +1201,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::One>; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -1240,7 +1213,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle true, InMemoryDataOperationEnum::Set, minimum_occupancy, - TailNumber::Two>; + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; Run(kernel); } } @@ -1250,18 +1261,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Three>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; Run(kernel); } } @@ -1271,18 +1283,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Four>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; Run(kernel); } } @@ -1292,18 +1305,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Five>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; Run(kernel); } } @@ -1312,18 +1326,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Six>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; Run(kernel); } } @@ -1333,18 +1348,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Seven>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; Run(kernel); } } @@ -1357,34 +1373,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -1392,34 +1410,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -1430,34 +1450,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -1465,34 +1487,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - NumGroupsToMerge, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -1505,7 +1529,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -1520,7 +1544,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } else { - const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage< GridwiseGemm, remove_reference_t, remove_reference_t,