Disable bwd weight split-k autodeduce for single stage kernels (#2856)

* Disable bwd weight split-k autodeduce for single stage kernels

* Update interface tests

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 29446da1d5]
This commit is contained in:
Bartłomiej Kocot
2025-09-19 16:27:50 +02:00
committed by GitHub
parent 240de6ee26
commit 38e1718bda
7 changed files with 96 additions and 33 deletions

View File

@@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
// clang-format on
ck::utils::conv::ConvParam conv_param;
std::vector<ck::index_t> split_ks{-1, 2};
ck::index_t split_k_ = 2;
template <ck::index_t NDimSpatial>
bool Run()
@@ -96,30 +96,24 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto conv = GroupedConvBwdWeightDeviceInstance{};
bool is_supported = true;
for(const auto split_k : split_ks)
{
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k);
is_supported &= conv.IsSupportedArgument(argument);
}
return is_supported;
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k_);
return conv.IsSupportedArgument(argument);
}
};
@@ -183,3 +177,12 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck)
is_supported = this->template Run<2>();
EXPECT_FALSE(is_supported);
}
TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce)
{
    // Ask for split-k auto-deduction (-1) instead of the fixture's explicit value.
    // NOTE(review): per the original comment this problem shape is otherwise a
    // supported one, so the rejection below is presumably caused by combining
    // auto-deduce with a single stage instance -- confirm against the device op.
    this->split_k_   = -1;
    this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};

    // Run<2>() reports whether the 2D instance accepts the argument; it must not.
    const bool is_supported{this->template Run<2>()};
    EXPECT_FALSE(is_supported);
}