Disable bwd weight split-k autodeduce for single stage kernels (#2856)

* Disable bwd weight split-k autodeduce for single stage kernels

* update interface tests

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Bartłomiej Kocot
2025-09-19 16:27:50 +02:00
committed by GitHub
parent 6cf3fdd21c
commit 29446da1d5
7 changed files with 96 additions and 33 deletions

View File

@@ -11,6 +11,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -11,6 +11,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -144,18 +144,39 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if(split_k < 0)
if constexpr(IsTwoStageNeeded)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else
{
split_k_ = split_k;
}
}
else
{
split_k_ = split_k;
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else
#endif
{
split_k_ = split_k;
}
}
if constexpr(IsTwoStageNeeded)
@@ -318,6 +339,16 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if constexpr(!IsTwoStageNeeded)
{
if(arg.split_k_ < 0)
{
return false;
}
}
#endif
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())

View File

@@ -671,6 +671,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -683,6 +684,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
grid_size);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -939,6 +941,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -553,6 +553,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -565,6 +566,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
grid_size);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -934,6 +936,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -524,6 +524,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -549,6 +550,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1275,6 +1277,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
// clang-format on
ck::utils::conv::ConvParam conv_param;
std::vector<ck::index_t> split_ks{-1, 2};
ck::index_t split_k_ = 2;
template <ck::index_t NDimSpatial>
bool Run()
@@ -96,30 +96,24 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto conv = GroupedConvBwdWeightDeviceInstance{};
bool is_supported = true;
for(const auto split_k : split_ks)
{
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k);
is_supported &= conv.IsSupportedArgument(argument);
}
return is_supported;
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k_);
return conv.IsSupportedArgument(argument);
}
};
@@ -183,3 +177,12 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck)
is_supported = this->template Run<2>();
EXPECT_FALSE(is_supported);
}
TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce)
{
// Supported version but with auto deduce and single stage
this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
this->split_k_ = -1;
bool is_supported = this->template Run<2>();
EXPECT_FALSE(is_supported);
}