[rocm-libraries] ROCm/rocm-libraries#5555 (commit 1d2c4c8)

[CK][CK Tile] Fix kbatch check in grouped conv and gemm
 kernels (#5555)

## Motivation

Fix the kbatch validity check in the grouped conv and gemm kernels so that
K values which are not an exact multiple of kbatch (i.e. have a tail) are accepted.

## Technical Details

Round up K / KPerXdl (integer ceiling division) and compare the result against
KBatch — instead of requiring K >= KPerXdl * KBatch exactly — so a tail in K is allowed.

## Test Plan

test_grouped_convnd_bwd_weight_tile

## Test Result

passed locally

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-21 22:56:19 +00:00
committed by assistant-librarian[bot]
parent 6b69ac9676
commit f79926009b
4 changed files with 10 additions and 7 deletions

View File

@@ -418,7 +418,8 @@ struct UniversalGemmKernel
}
}
if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
if(integer_divide_ceil(kargs.K, GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})) <
kargs.k_batch)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{

View File

@@ -574,7 +574,9 @@ struct GroupedConvolutionBackwardWeightKernel
}
}
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
if(integer_divide_ceil(kargs.GemmK,
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
kargs.k_batch)
{
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
kargs.GemmK,

View File

@@ -178,11 +178,11 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
});
const bool valid = report.get_errors().empty();
best_avg_time = std::min(best_avg_time, avg_time);
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
best_split_k = best_avg_time < avg_time ? best_split_k : k_batch;
if(valid)
{
best_avg_time = std::min(best_avg_time, avg_time);
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
best_split_k = best_avg_time < avg_time ? best_split_k : k_batch;
std::cout << "[Valid] Perf: " << std::setw(10) << avg_time << " ms," << " "
<< op_name << ", SplitK " << k_batch << std::endl;
}

View File

@@ -219,12 +219,12 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
tensor_layout::convolution::NHWGK>::type;
// k_batch = 7 should pass
auto host_args_kbatch_6 = create_2d_host_args(6);
auto host_args_kbatch_6 = create_2d_host_args(7);
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
// k_batch = 8 should fail for half_t output
auto host_args_kbatch_7 = create_2d_host_args(7);
auto host_args_kbatch_7 = create_2d_host_args(8);
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
}