mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Fix bug and disable splitK=-1 tests for wmma
This commit is contained in:
@@ -447,6 +447,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(WeiDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -629,6 +634,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -757,12 +763,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_e_grid, 0, arg.GetWorkspaceETensorSizeBytes(), stream_config.stream_id_));
|
||||
}
|
||||
hip_check_error(
|
||||
hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
@@ -1047,13 +1049,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// workaround: disable when K, C is even
|
||||
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
|
||||
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
|
||||
@@ -44,6 +44,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
if((split_k < 1) && (ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user