mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fixed handling of split-K autodeduce argument for grouped convolution (#3024)
* Fix handling of split-K autodeduce argument.
* Fix clang formatting.
* Test fix.
* Fix clang formatting.
[ROCm/composable_kernel commit: 7e44b845b5]
This commit is contained in:
@@ -689,6 +689,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split-K autodeduction is not supported
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
|
||||
@@ -1523,6 +1523,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Split-K autodeduction is not supported.
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
|
||||
@@ -688,6 +688,12 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Split-K autodeduction is not supported
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -47,10 +47,11 @@ class TestGroupedConvndBwdData : public ::testing::Test
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
ck::index_t split_k{1};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
@@ -112,7 +113,8 @@ class TestGroupedConvndBwdData : public ::testing::Test
|
||||
input_right_pads,
|
||||
Pass{},
|
||||
Pass{},
|
||||
Pass{});
|
||||
Pass{},
|
||||
split_k);
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
@@ -176,3 +178,24 @@ TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck)
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefault, SplitK)
|
||||
{
|
||||
if(ck::is_xdl_supported())
|
||||
{
|
||||
// SplitK = 1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
this->split_k = 1;
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
|
||||
// Split-K autodeduce
|
||||
this->split_k = -1;
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
else
|
||||
{
|
||||
GTEST_SKIP() << "XDL ops not supported on this device";
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user