[rocm-libraries] ROCm/rocm-libraries#5387 (commit 0c259bd)

[CK][CK Tile] Grouped Convolution Backward Weight set of
 fixes (#5387)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Grouped Convolution Backward Weight split k fixes for CK tile kernels

## Technical Details

- get k batch from kargs to get deduced k batch
- multiply zeroing size by data type size
- disable v6 (producing a incorrect results)

## Test Plan

test_grouped_convnd_bwd_weight_tile

## Test Result

Pass

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-13 16:19:50 +00:00
committed by assistant-librarian[bot]
parent 574c1c121a
commit b8108662da
6 changed files with 28 additions and 11 deletions

View File

@@ -6,6 +6,7 @@
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
@@ -56,6 +57,7 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
if(!Conv::IsSupportedArgument(kargs))
return RunResult::not_supported("unsupported ck_tile arguments");
using Types = ck_tile::builder::factory::internal::TileConvTensorTypes<SIGNATURE.data_type>;
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
std::end(kargs.wei_g_k_c_xs_lengths.data),
1,
@@ -64,10 +66,13 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
if(args.k_batch > 1)
if(kargs.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(kargs.wei_ptr, 0, zeroing_size, s_conf.stream_id_));
hipMemsetAsync(kargs.wei_ptr,
0,
zeroing_size * sizeof(typename Types::EDataType),
s_conf.stream_id_));
}
}
};
@@ -156,7 +161,7 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
if(args.k_batch > 1)
if(kargs.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,

View File

@@ -434,7 +434,9 @@ def parse_bwd_weight_instances(instances, problem_name):
if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False:
print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
continue
if pipeline_version == "V6":
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
continue
conv = ConvInstanceTemplateParams(
spec,