Fix grouped conv wrw kernels names (#3494)

This commit is contained in:
Bartłomiej Kocot
2025-12-30 16:45:39 +01:00
committed by GitHub
parent 53a1e4f551
commit 2b8302eb6d
2 changed files with 330 additions and 306 deletions

View File

@@ -48,7 +48,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3(
kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
@@ -468,7 +468,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -916,7 +916,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
{
if(gemm_arg.KBatch > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -931,7 +931,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -957,7 +957,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
{
if(gemm_arg.KBatch > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -972,7 +972,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,

View File

@@ -48,7 +48,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage(
typename GridwiseGemm::Argument karg,
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
@@ -106,7 +106,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds(
typename GridwiseGemm::Argument karg,
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
@@ -532,7 +532,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -549,7 +549,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -997,7 +997,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(gemm_arg.KBatch > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -1012,7 +1012,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -1033,43 +1033,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -1080,7 +1045,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
@@ -1090,18 +1093,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
@@ -1111,18 +1115,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
@@ -1132,18 +1137,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
@@ -1152,18 +1158,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
@@ -1173,18 +1180,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
@@ -1193,43 +1201,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -1240,7 +1213,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
@@ -1250,18 +1261,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
@@ -1271,18 +1283,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
@@ -1292,18 +1305,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
@@ -1312,18 +1326,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
@@ -1333,18 +1348,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
@@ -1357,34 +1373,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
@@ -1392,34 +1410,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
@@ -1430,34 +1450,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
@@ -1465,34 +1487,36 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
const auto kernel =
kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
@@ -1505,7 +1529,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
if(gemm_arg.KBatch > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
@@ -1520,7 +1544,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_two_stage<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,