Add client example of grouped conv2d backward data (data type: fp16) (#481)

* Improve example reusability

* Remove no-longer used file

* Rename folder of grouped_conv_bwd_data example

* Add normal grouped conv bwd example

* Add interface 'DeviceGroupedConvBwdData'

* Prettify comment of device op type arguments

* Add grouped conv2d/conv3d backward data fp16 instances

* Fix wrong template argument

* Add grouped_conv2d_bwd_data client example

* Use simpler expression to calculate memory size

* Fix formating

* Remove grouped_conv3d_bw_data instances

Underlying device operator is not ready to handle 3D input

* Remove no-longer necessary include directive

* Add missing include directive

* Use more realistic conv param in example
This commit is contained in:
Po Yen Chen
2022-11-03 06:54:41 +08:00
committed by GitHub
parent 1a0b0e7bec
commit 9e57a290af
15 changed files with 1002 additions and 241 deletions

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NDimSpatial,
typename InputLayout,
typename WeightLayout,
typename OutputLayout,
typename InputDataType,
typename WeightDataType,
typename OutputDataType,
typename InputElementwiseOperation,
typename WeightElementwiseOperation,
typename OutputElementwiseOperation>
struct DeviceGroupedConvBwdData : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_input,
const void* p_weight,
const void* p_output,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& weight_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& weight_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& output_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& output_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const InputElementwiseOperation& input_element_op,
const WeightElementwiseOperation& weight_element_op,
const OutputElementwiseOperation& output_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -6,6 +6,7 @@
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp"
namespace ck {
namespace tensor_operation {
@@ -62,6 +63,100 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
ALayout,
BLayout,
Tuple<>,
ELayout,
ADataType,
BDataType,
Tuple<>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
: public DeviceGroupedConvBwdData<NDimSpatial,
ELayout,
BLayout,
ALayout,
EDataType,
BDataType,
ADataType,
CDEElementwiseOperation,
BElementwiseOperation,
AElementwiseOperation>
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, // output image
const void* p_b, // weight
const std::array<const void*, 0>&, // bias
void* p_e, // input image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, 0>&, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, 0>&, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_input,
const void* p_weight,
const void* p_output,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& weight_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& weight_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& output_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& output_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const CDEElementwiseOperation& input_element_op,
const BElementwiseOperation& weight_element_op,
const AElementwiseOperation& output_element_op) override final
{
return MakeArgumentPointer(p_output,
p_weight,
std::array<const void*, 0>{},
p_input,
output_g_n_k_wos_lengths,
output_g_n_k_wos_strides,
weight_g_k_c_xs_lengths,
weight_g_k_c_xs_strides,
std::array<std::array<index_t, NDimSpatial + 3>, 0>{},
std::array<std::array<index_t, NDimSpatial + 3>, 0>{},
input_g_n_c_wis_lengths,
input_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
output_element_op,
weight_element_op,
input_element_op);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck