Grouped conv bwd data NGCHW (#1967)

* Grouped conv bwd data NGCHW

* fixes

* fix

* Improvements

* Fix

* Fix

* add client example
This commit is contained in:
Bartłomiej Kocot
2025-03-17 13:32:00 +01:00
committed by GitHub
parent 52b1cd7780
commit c2e4898b4b
26 changed files with 1351 additions and 71 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -51,6 +51,9 @@ using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NGKHW, GKYXC, NGCHW>,
std::tuple<ck::half_t, NGKHW, GKYXC, NGCHW>,
std::tuple<ck::bhalf_t, NGKHW, GKYXC, NGCHW>,
std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
@@ -58,6 +61,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NGKDHW, GKZYXC, NGCDHW>,
std::tuple<ck::half_t, NGKDHW, GKZYXC, NGCDHW>,
std::tuple<ck::bhalf_t, NGKDHW, GKZYXC, NGCDHW>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;