Grouped conv bwd data NGCHW (#1967)

* Grouped conv bwd data NGCHW

* fixes

* fix

* Improvements

* Fix

* Fix

* add client example
This commit is contained in:
Bartłomiej Kocot
2025-03-17 13:32:00 +01:00
committed by GitHub
parent 52b1cd7780
commit c2e4898b4b
26 changed files with 1351 additions and 71 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -126,12 +126,13 @@ __global__ void
OutDataTypePointerTuple p_out_global_with_offset_tuple;
static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
p_in_global_with_offset_tuple(i) =
p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
});
static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
p_out_global_with_offset_tuple(i) =
p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
});
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,