From 0d94859dca1c6e9d956fbd16a815c507cc58cfbc Mon Sep 17 00:00:00 2001
From: jakpiase
Date: Wed, 24 Dec 2025 00:10:13 +0100
Subject: [PATCH] [CK_TILE] Minor splitk bugfix for gemms and conv (#3387)

* fix for splitk if splitk < grid

* add different splitk implementation

* minor bugfix for streamk gemm

* Add test

---------

Co-authored-by: Bartlomiej Kocot

[ROCm/composable_kernel commit: c0797c167143aa750936c108caa0945640eeefd1]
---
 .../ops/gemm/kernel/universal_gemm_kernel.hpp | 56 +++++++++++++++----
 ...ped_convolution_backward_weight_kernel.hpp |  9 +++
 .../test_ck_tile_grouped_conv_bwd_weight.cpp  | 28 +++++++++-
 3 files changed, 80 insertions(+), 13 deletions(-)

diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
index 5f7e78fac2..77952c9afd 100644
--- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -323,22 +323,38 @@ struct UniversalGemmKernel
 
     struct SplitKBatchOffset
     {
-        __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
+        // This structure distributes work evenly among splitk workgroups.
+        // It is based on the principle that, if there is enough work to fill all workgroups,
+        // the (K / K1) loop iterations can be distributed among k_batch workgroups so that
+        // each workgroup runs either ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
+        // iterations, and the potential tail is left to the last workgroup (index splitk - 1).
+        __device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
         {
-            constexpr auto K1   = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
-            const index_t K_t   = amd_wave_read_first_lane(kargs.k_batch * K1);
-            const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
+            constexpr auto K1     = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+            const index_t num_all = amd_wave_read_first_lane(
+                kargs.K / K1); // num of all loops, not including the potential tail
+            index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
+            num_full         = num_full == 0 ? kargs.k_batch : num_full;
+
+            const index_t num_full_iters =
+                amd_wave_read_first_lane(std::max(integer_divide_ceil(num_all, kargs.k_batch), 1));
+            const index_t full_k_read    = num_full_iters * K1;
+            const index_t partial_k_read = (num_full_iters - 1) * K1;
 
             static_for<0, NumATensor, 1>{}([&](auto index) {
                 using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
                 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
                 {
-                    as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
+                    as_k_split_offset[index] =
+                        amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
+                                                 std::max(k_id - num_full, 0) * partial_k_read);
                 }
                 else if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::ColumnMajor>)
                 {
                     as_k_split_offset[index] =
-                        amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
+                        amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
+                                                  std::max(k_id - num_full, 0) * partial_k_read) *
+                                                 kargs.stride_As[index]);
                 }
             });
 
@@ -347,21 +363,30 @@ struct UniversalGemmKernel
                 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
                 {
                     bs_k_split_offset[index] =
-                        amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
+                        amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
+                                                  std::max(k_id - num_full, 0) * partial_k_read) *
+                                                 kargs.stride_Bs[index]);
                 }
                 else if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
                 {
-                    bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
+                    bs_k_split_offset[index] =
+                        amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
+                                                 std::max(k_id - num_full, 0) * partial_k_read);
                 }
             });
 
-            if(k_id < static_cast<std::size_t>(kargs.k_batch - 1))
+            if(k_id == kargs.k_batch - 1)
             {
-                splitted_k = amd_wave_read_first_lane(KRead);
+                splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
+                             std::max(k_id - num_full, 0) * partial_k_read;
+            }
+            else if(k_id < num_full)
+            {
+                splitted_k = full_k_read;
             }
             else
             {
-                splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
+                splitted_k = partial_k_read;
             }
         }
 
@@ -385,6 +410,15 @@ struct UniversalGemmKernel
             }
         }
 
+        if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
+            }
+            return false;
+        }
+
         const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA() : GemmPipeline::template GetVectorSizeA();
 
         bool AsTesnorIsValid = {true};

diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
index 6034dfc3de..4b7ad72ffc 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
@@ -568,6 +568,15 @@ struct GroupedConvolutionBackwardWeightKernel
             }
         }
 
+        if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
+            }
+            return false;
+        }
+
         const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
         const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
 
diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp
index f37065f7c7..bdce90e385 100644
--- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp
+++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp
@@ -173,6 +173,11 @@ static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
     return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
 }
 
+static GroupedConvBwdWeightHostArgs create_large_2d_host_args(index_t k_batch)
+{
+    return create_2d_host_args(2, 2, 8, 8, 3, 3, 70, 70, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
+}
+
 class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
 {
 };
@@ -227,6 +232,25 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreat
     EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
 }
 
+TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
+{
+    using Kernel = typename BuildKernel::type;
+
+    // k_batch = 6 should pass
+    auto host_args_kbatch_6 = create_2d_host_args(6);
+    auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
+    EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
+
+    // k_batch = 7 should fail, GemmK is too small to give work to every workgroup
+    auto host_args_kbatch_7 = create_2d_host_args(7);
+    auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
+    EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
+}
+
 TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
 {
     using Kernel = typename BuildKernel::type;
@@ -233,12 +257,12 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKB
 
     // k_batch = 128 should pass
-    auto host_args_kbatch_128 = create_2d_host_args(128);
+    auto host_args_kbatch_128 = create_large_2d_host_args(128);
     auto kargs_128 =
         typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
     EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
 
     // k_batch = 129 should fail for half_t output
-    auto host_args_kbatch_129 = create_2d_host_args(129);
+    auto host_args_kbatch_129 = create_large_2d_host_args(129);
     auto kargs_129 =
         typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
     EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));
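
Reviewer note: the split-K distribution above is easiest to sanity-check on the
host. The standalone C++ sketch below mirrors the arithmetic of
SplitKBatchOffset (num_all, num_full, full_k_read, partial_k_read, splitted_k)
and the new IsSupportedArgument guard. It is not part of the patch; main() and
the K, K1 and k_batch values in it are illustrative assumptions only.

    #include <algorithm>
    #include <cassert>
    #include <cstdio>

    int main()
    {
        const int K       = 76; // total GEMM K extent (assumed example value)
        const int K1      = 8;  // warp-tile K, BlockGemmShape::WarpTile::at(number<2>{})
        const int k_batch = 4;  // number of split-K workgroups (assumed example value)

        // The guard added to IsSupportedArgument: every workgroup must get at
        // least one K1-sized chunk, otherwise part of the GPU would stay idle.
        assert(K >= K1 * k_batch);

        const int num_all = K / K1; // number of full K1 loops, tail excluded
        int num_full      = num_all % k_batch;
        num_full          = num_full == 0 ? k_batch : num_full;

        // integer_divide_ceil(num_all, k_batch) in the kernel
        const int num_full_iters = std::max((num_all + k_batch - 1) / k_batch, 1);
        const int full_k_read    = num_full_iters * K1;       // first num_full workgroups
        const int partial_k_read = (num_full_iters - 1) * K1; // remaining workgroups

        int covered = 0;
        for(int k_id = 0; k_id < k_batch; ++k_id)
        {
            const int offset = std::min(k_id, num_full) * full_k_read +
                               std::max(k_id - num_full, 0) * partial_k_read;

            int splitted_k;
            if(k_id == k_batch - 1)
                splitted_k = K - offset; // last workgroup also absorbs the K % K1 tail
            else if(k_id < num_full)
                splitted_k = full_k_read;
            else
                splitted_k = partial_k_read;

            assert(offset == covered); // chunks are contiguous
            covered += splitted_k;
            std::printf("k_id %d: offset %d, splitted_k %d\n", k_id, offset, splitted_k);
        }
        assert(covered == K); // chunks exactly cover [0, K)
        return 0;
    }

With these example values (num_all = 9, num_full = 1, full_k_read = 24,
partial_k_read = 16) the chunks come out as 24, 16, 16 and 20 elements, i.e.
3, 2, 2 and 2 main-loop iterations per workgroup with the last one absorbing
the K % K1 = 4 tail, which is exactly the ceil / ceil - 1 split described in
the SplitKBatchOffset comment.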