[Conv] Enable bwd weight splitk autodeduction with cap (#3656)

* Enable bwd weight splitk autodeduction with cap

* Fix error threshold calculations

* Add missing logic to wmma multiple d kernel

* Fix threshold calculation

* Update test with new applicability
This commit is contained in:
Johannes Graner
2026-01-29 18:40:28 +01:00
committed by GitHub
parent e33f15709f
commit fabac7e2c3
10 changed files with 91 additions and 76 deletions

View File

@@ -11,8 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -11,8 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -162,7 +162,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
}
else
{
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
@@ -171,9 +170,11 @@ struct DeviceGroupedConvBwdWeight_Explicit
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -338,16 +339,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if constexpr(!IsTwoStageNeeded)
{
if(arg.k_batch_ < 0)
{
return false;
}
}
#endif
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())

View File

@@ -22,6 +22,7 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
// Queries, at construction time, the maximum number of workgroups that can be
// resident on a single compute unit for this operator's kernel. The result is
// consumed by the split-K auto-deduction logic (via max_occupancy_).
struct ActiveWorkgroupsPerCU
{
    ActiveWorkgroupsPerCU()
    {
        // Occupancy can only be queried on the architectures this kernel
        // targets; on any other device keep the safe default of 1.
        if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
        {
            return;
        }
        constexpr int dynamic_smem_size = 0;
        // Minimum occupancy the kernel was compiled to guarantee; it is also a
        // template argument of the kernel instantiated below, so the query must
        // match the launched configuration exactly.
        constexpr index_t minimum_occupancy =
            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
        int max_occupancy = 0;
        if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
            // TODO: implement occupancy query for the v4 pipeline; until then
            // the clamp below leaves max_occupancy_ at 1.
        }
        else
        {
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
                    GridwiseGemm,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                    ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
                    true,
                    InMemoryDataOperationEnum::AtomicAdd,
                    minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));
        }
        // The HIP query may legitimately report 0 when the kernel does not fit;
        // split-K deduction needs at least 1.
        max_occupancy_ = std::max(1, max_occupancy);
    }
    // Default-initialized so the early return above (unsupported architecture)
    // never leaves this indeterminate; 1 matches the clamped minimum.
    int max_occupancy_ = 1;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(
@@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -594,7 +594,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -611,6 +610,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -620,7 +622,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1399,13 +1400,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
// check device
if constexpr(DirectLoad)
{