mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Add client example of grouped conv2d backward weight (data type: fp16) (#498)
* Remove redundant CMake setting * Extract common code from files * Rename folder 'convnd' to 'conv' * Use std::array<> to accept compile-time kwnown # of arguments * Fix compilation error of tuning parameter * In example, use same setting as unit-test * Remove no-longer used include directive * Add interface for grouped conv bwd weight * Add group support for conv bwd weight * Add grouped conv bwd weight example * Use group parameter in example * Rename example folder * Remove non-grouped version example source files * Rename device op template * Add group support to convolution backward weight * Remove debug messages * Use smaller group size in example * Use named variable as loop terminate condition * Prettify example output message * Enlarge used grid size * Allow real grid size exceeds expected grid size * Rename interface file * Add client example for grouped conv2d bwd weight * Fix wrong include directive * Rename client example folder
This commit is contained in:
@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
|
||||
std::size_t N = arg.output_.GetLengths()[1];
|
||||
|
||||
std::size_t Ho = arg.output_.GetLengths()[3];
|
||||
std::size_t Wo = arg.output_.GetLengths()[4];
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
|
||||
for(std::size_t n = 0; n < N; ++n)
|
||||
{
|
||||
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
|
||||
for(std::size_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
|
||||
for(std::size_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
auto wi =
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv2d backward weight
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv3d backward weight
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
|
||||
is_same_v<OutLayout, NWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
|
||||
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
||||
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,235 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv2d backward weight
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv3d backward weight
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
|
||||
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user