mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Grouped conv bwd data NGCHW (#1967)
* Grouped conv bwd data NGCHW
* fixes
* fix
* Improvements
* Fix
* Fix
* add client example
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -126,12 +126,13 @@ __global__ void
|
||||
OutDataTypePointerTuple p_out_global_with_offset_tuple;
|
||||
|
||||
static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
|
||||
p_in_global_with_offset_tuple(i) =
|
||||
p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
|
||||
});
|
||||
|
||||
static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_out_global_with_offset_tuple(i) =
|
||||
p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
|
||||
p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
|
||||
});
|
||||
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
|
||||
|
||||
Reference in New Issue
Block a user