[rocm-libraries] ROCm/rocm-libraries#5387 (commit 0c259bd)

[CK][CK Tile] Grouped Convolution Backward Weight set of fixes (#5387)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Split-K fixes for Grouped Convolution Backward Weight in CK Tile kernels.

## Technical Details

- get k_batch from kargs so that the deduced k_batch value is used
- multiply zeroing size by data type size
- disable v6 (it produces incorrect results)

## Test Plan

test_grouped_convnd_bwd_weight_tile

## Test Result

Pass

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-13 16:19:50 +00:00
committed by assistant-librarian[bot]
parent 574c1c121a
commit b8108662da
6 changed files with 28 additions and 11 deletions

View File

@@ -73,7 +73,7 @@ void run_cpu_validation(const ckt::Args<SIGNATURE>& args,
template <auto SIGNATURE>
std::tuple<double, double>
get_rtol_atol(const int num_accums, const int num_accums_split_k, const float max_accumulated_value)
get_rtol_atol(const int num_accums, const int k_batch, const float max_accumulated_value)
{
using WeiDataType =
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
@@ -84,6 +84,8 @@ get_rtol_atol(const int num_accums, const int num_accums_split_k, const float ma
using ComputeType = WeiDataType;
using AccDataType = float;
// Assign middle value of the range for auto deduce
const int num_accums_split_k = k_batch > 0 ? k_batch : 64;
auto rtol = ck_tile::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
auto atol = ck_tile::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
@@ -150,14 +152,17 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
auto run_alg = [&](auto&& run_alg_func) {
for(auto& k_batch : split_k_values)
{
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
ckt::Args<SIGNATURE> args_k_batch = args;
args_k_batch.k_batch = k_batch;
std::tie(is_supported, avg_time, op_name) =
run_alg_func(args_k_batch, inputs, outputs, s_conf);
if(is_supported)
{
ckt::ValidationReport report;
auto&& [rtol, atol] =
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
ckt::Outputs<SIGNATURE>::reflect(
args,
args_k_batch,
[&](std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
@@ -182,7 +187,7 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
<< " Is all zero:" << error.is_all_zero()
<< " max err: " << error.max_error << std::endl;
// Check with cpu verification to get a values
run_cpu_validation<SIGNATURE>(args, outputs, reference.get());
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
}
all_instances_valid = false;
}