[CK_TILE] Minor splitk bugfix for gemms and conv (#3387)

* fix for splitk if splitk < grid

* add different splitk implementation

* minor bugfix for streamk gemm

* Add test

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
jakpiase
2025-12-24 00:10:13 +01:00
committed by GitHub
parent e1381d6a71
commit c0797c1671
3 changed files with 80 additions and 13 deletions

View File

@@ -323,22 +323,38 @@ struct UniversalGemmKernel
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
// This structure distributes work evenly among splitkk workgroups
// It's based on a principle that if there is enough work to fill all workgroups,
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
// and leave the potential tail for last(splitk - 1) indexed workgroup.
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
{
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t num_all = amd_wave_read_first_lane(
kargs.K / K1); // num of all loops not including potential tail
index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
num_full = num_full == 0 ? kargs.k_batch : num_full;
const index_t num_full_iters =
amd_wave_read_first_lane(std::max(integer_divide_ceil(num_all, kargs.k_batch), 1));
const index_t full_k_read = num_full_iters * K1;
const index_t partial_k_read = (num_full_iters - 1) * K1;
static_for<0, NumATensor, 1>{}([&](auto index) {
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
{
as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
as_k_split_offset[index] =
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
std::max(k_id - num_full, 0) * partial_k_read);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
{
as_k_split_offset[index] =
amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
std::max(k_id - num_full, 0) * partial_k_read) *
kargs.stride_As[index]);
}
});
@@ -347,21 +363,30 @@ struct UniversalGemmKernel
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
{
bs_k_split_offset[index] =
amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
std::max(k_id - num_full, 0) * partial_k_read) *
kargs.stride_Bs[index]);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
{
bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
bs_k_split_offset[index] =
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
std::max(k_id - num_full, 0) * partial_k_read);
}
});
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
if(k_id == kargs.k_batch - 1)
{
splitted_k = amd_wave_read_first_lane(KRead);
splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
std::max(k_id - num_full, 0) * partial_k_read;
}
else if(k_id < num_full)
{
splitted_k = full_k_read;
}
else
{
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
splitted_k = partial_k_read;
}
}
@@ -385,6 +410,15 @@ struct UniversalGemmKernel
}
}
if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
}
return false;
}
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
: GemmPipeline::template GetVectorSizeA<false>();
bool AsTesnorIsValid = {true};

View File

@@ -568,6 +568,15 @@ struct GroupedConvolutionBackwardWeightKernel
}
}
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
}
return false;
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

View File

@@ -173,6 +173,11 @@ static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
}
static GroupedConvBwdWeightHostArgs create_large_2d_host_args(index_t k_batch)
{
return create_2d_host_args(2, 2, 8, 8, 3, 3, 70, 70, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
}
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
{
};
@@ -227,6 +232,25 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreat
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
}
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
{
using Kernel = typename BuildKernel<half_t,
TestConvConfig,
tensor_layout::convolution::NHWGC,
tensor_layout::convolution::GKYXC,
tensor_layout::convolution::NHWGK>::type;
// k_batch = 128 should pass
auto host_args_kbatch_6 = create_2d_host_args(6);
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
// k_batch = 129 should fail for half_t output
auto host_args_kbatch_7 = create_2d_host_args(7);
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
}
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
{
using Kernel = typename BuildKernel<half_t,
@@ -236,13 +260,13 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKB
tensor_layout::convolution::NHWGK>::type;
// k_batch = 128 should pass
auto host_args_kbatch_128 = create_2d_host_args(128);
auto host_args_kbatch_128 = create_large_2d_host_args(128);
auto kargs_128 =
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
// k_batch = 129 should fail for half_t output
auto host_args_kbatch_129 = create_2d_host_args(129);
auto host_args_kbatch_129 = create_large_2d_host_args(129);
auto kargs_129 =
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));