[rocm-libraries] ROCm/rocm-libraries#4872 (commit ca623f7)

[CK] Small improvements for grouped conv backward weight
 (#4872)

## Motivation

Improvements for CK Tile convolution builder run function and atol/rtol
calculations.

## Technical Details

- Add preprocessing function for wrw when k_batch is larger than 1 for
builder run function
- Divide num acums by number of groups to get real number of accums

## Test Plan

CI wrw tests

## Test Result

pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-783
This commit is contained in:
Bartłomiej Kocot
2026-02-25 20:11:01 +00:00
committed by assistant-librarian[bot]
parent c90a363589
commit eede24de0d
5 changed files with 28 additions and 7 deletions

View File

@@ -53,11 +53,28 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
if(!Conv::IsSupportedArgument(kargs))
return RunResult::not_supported("unsupported ck_tile arguments");
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
std::end(kargs.wei_g_k_c_xs_lengths.data),
1,
std::multiplies<std::size_t>());
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
if(args.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(kargs.wei_ptr, 0, zeroing_size, s_conf.stream_id_));
}
}
};
constexpr index_t minimum_occupancy =
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
return RunResult::from_runtime(ck_tile::launch_kernel(
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
s_conf,
preprocess,
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
}
} // namespace detail