[CK_TILE] Minor splitk bugfix for gemms and conv (#3387)

* fix for splitk if splitk < grid

* add different splitk implementation

* minor bugfix for streamk gemm

* Add test

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
Authored by jakpiase on 2025-12-24 00:10:13 +01:00; committed by GitHub.
parent e1381d6a71
commit c0797c1671
3 changed files with 80 additions and 13 deletions

View File

@@ -173,6 +173,11 @@ static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
}
// Same as create_2d_host_args(k_batch) but with 70x70 input spatial dims,
// giving a larger reduction (GemmK) so that bigger k_batch splits stay valid.
// NOTE(review): argument order presumably mirrors create_2d_host_args
// (counts, filter, spatial, strides, dilations, pads, k_batch) — confirm there.
static GroupedConvBwdWeightHostArgs create_large_2d_host_args(index_t k_batch)
{
    constexpr index_t large_spatial = 70;
    return create_2d_host_args(
        2, 2, 8, 8, 3, 3, large_spatial, large_spatial, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
}
// Empty GoogleTest fixture: exists only to group the IsSupportedArgument
// tests below under one suite name; no shared state or SetUp/TearDown needed.
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
{
};
@@ -227,6 +232,25 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreat
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
}
// Checks the K0/k_batch limitation: with the small 7x7 problem, splitting the
// reduction 6 ways is still a supported argument, but 7 ways is rejected.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
{
using Kernel = typename BuildKernel<half_t,
TestConvConfig,
tensor_layout::convolution::NHWGC,
tensor_layout::convolution::GKYXC,
tensor_layout::convolution::NHWGK>::type;
// k_batch = 6 should pass (reduction still large enough to split 6 ways —
// confirm exact limit against the kernel's IsSupportedArgument check)
auto host_args_kbatch_6 = create_2d_host_args(6);
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
// k_batch = 7 should fail: one split more than the problem size supports
auto host_args_kbatch_7 = create_2d_host_args(7);
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
}
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
{
using Kernel = typename BuildKernel<half_t,
@@ -236,13 +260,13 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKB
tensor_layout::convolution::NHWGK>::type;
// k_batch = 128 should pass
auto host_args_kbatch_128 = create_2d_host_args(128);
auto host_args_kbatch_128 = create_large_2d_host_args(128);
auto kargs_128 =
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
// k_batch = 129 should fail for half_t output
auto host_args_kbatch_129 = create_2d_host_args(129);
auto host_args_kbatch_129 = create_large_2d_host_args(129);
auto kargs_129 =
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));