mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] Minor splitk bugfix for gemms and conv (#3387)
* fix for splitk if splitk < grid * add different splitk implementation * minor bugfix for streamk gemm * Add test --------- Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -173,6 +173,11 @@ static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
|
||||
return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
|
||||
}
|
||||
|
||||
// Builds 2D host arguments by forwarding to the 17-argument create_2d_host_args overload.
// The literal arguments are the same as in create_2d_host_args(index_t k_batch) except that
// the 7th and 8th values are 70 instead of 7 (presumably the input spatial extents -- TODO
// confirm against the full overload's parameter list). k_batch is passed through unchanged.
static GroupedConvBwdWeightHostArgs create_large_2d_host_args(index_t k_batch)
|
||||
{
|
||||
return create_2d_host_args(2, 2, 8, 8, 3, 3, 70, 70, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
|
||||
}
|
||||
|
||||
// Empty GoogleTest fixture; it only provides the suite name shared by the TEST_F cases that
// exercise Kernel::IsSupportedArgument.
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
@@ -227,6 +232,25 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreat
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
|
||||
}
|
||||
|
||||
// Checks that IsSupportedArgument accepts k_batch = 6 and rejects k_batch = 7 for the small
// 2D problem built by create_2d_host_args(k_batch), using a half_t kernel with
// NHWGC / GKYXC / NHWGK layouts.
// NOTE(review): the exact rule that makes 7 unsupported is not visible here; judging by the
// test name it is a limit relating k_batch to the K0 dimension -- confirm in IsSupportedArgument.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
TestConvConfig,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
// k_batch = 6 should pass
|
||||
auto host_args_kbatch_6 = create_2d_host_args(6);
|
||||
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
|
||||
|
||||
// k_batch = 7 should fail
|
||||
auto host_args_kbatch_7 = create_2d_host_args(7);
|
||||
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
|
||||
}
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
@@ -236,13 +260,13 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKB
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
// k_batch = 128 should pass
|
||||
auto host_args_kbatch_128 = create_2d_host_args(128);
|
||||
auto host_args_kbatch_128 = create_large_2d_host_args(128);
|
||||
auto kargs_128 =
|
||||
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
|
||||
|
||||
// k_batch = 129 should fail for half_t output
|
||||
auto host_args_kbatch_129 = create_2d_host_args(129);
|
||||
auto host_args_kbatch_129 = create_large_2d_host_args(129);
|
||||
auto kargs_129 =
|
||||
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));
|
||||
|
||||
Reference in New Issue
Block a user