mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add grouped conv bwd weight multi d kernel (#1237)
* Add grouped conv bwd weight multi d kernel * Reference fix * Fix cmake files * bwd weight scale only xdl * Fixes * Fix client conv fwd example
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DsDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedConvBwdWeightMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsLayout::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -137,34 +137,6 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
// 1d
|
||||
static constexpr bool is_NWGK_GKXC_NWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
|
||||
static constexpr bool is_GNWK_GKXC_GNWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
|
||||
// 2d
|
||||
static constexpr bool is_NHWGK_GKYXC_NHWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
|
||||
static constexpr bool is_GNHWK_GKYXC_GNHWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
|
||||
// 3d
|
||||
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
|
||||
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Dl;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
@@ -1065,9 +1037,15 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
if(arg.k_batch_ != 1)
|
||||
return false;
|
||||
|
||||
if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
|
||||
(NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
|
||||
(NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
|
||||
if constexpr(!((NDimSpatial == 1 &&
|
||||
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(NDimSpatial == 2 &&
|
||||
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(NDimSpatial == 3 &&
|
||||
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -90,16 +90,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// 3d
|
||||
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
|
||||
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
@@ -218,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
|
||||
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
|
||||
@@ -720,7 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// 1d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NWGK_GKXC_NWGC()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNWK_GKXC_GNWC()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
|
||||
}
|
||||
// 2d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NHWGK_GKYXC_NHWGC()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNHWK_GKYXC_GNHWC()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
|
||||
}
|
||||
// 3d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
|
||||
}
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user