Fix bug and disable splitK=-1 tests for wmma

This commit is contained in:
Enrico Degregori
2025-08-07 07:27:11 +00:00
parent 37b6d28dc0
commit 9dbbb07953
2 changed files with 13 additions and 13 deletions

View File

@@ -447,6 +447,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(WeiDataType);
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -629,6 +634,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -757,12 +763,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
hip_check_error(hipMemsetAsync(
p_e_grid, 0, arg.GetWorkspaceETensorSizeBytes(), stream_config.stream_id_));
}
hip_check_error(
hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
};
const auto Run = [&](const auto& kernel) {
@@ -1047,13 +1049,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// workaround: disable when K or C is even
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
{
return false;
}
#endif
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{

View File

@@ -44,6 +44,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
}
}
if((split_k < 1) && (ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return true;
}
return false;
}