mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Support NHWGC conv2d_bwd_weight (#769)
* Support NHWGC conv2d_bwd_weight
* Fix client example
* Fix client example
* Fix comments
* Redesign grouped_conv_bwd_weight instances
* Clang format fix
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 1ee99dcaa6]
This commit is contained in:
@@ -101,13 +101,15 @@ template <ck::index_t NumDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool run_grouped_conv_bwd_weight(
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
|
||||
@@ -157,6 +159,8 @@ bool run_grouped_conv_bwd_weight(
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -224,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -22,6 +22,15 @@ static constexpr ck::index_t C = 192;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -31,7 +40,19 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1})
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -25,6 +25,17 @@ static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -34,8 +45,19 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1})
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -41,13 +52,15 @@ int main()
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -37,17 +48,20 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
OutLayout>(
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
|
||||
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
using InDataType = BF16;
|
||||
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
|
||||
@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
|
||||
NDimSpatial, // NDimSpatial
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
|
||||
NDimSpatial, // NDimSpatial
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
|
||||
@@ -75,6 +75,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
@@ -85,6 +87,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
|
||||
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
|
||||
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
|
||||
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
@@ -103,6 +107,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
|
||||
@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
} // function end
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
InElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_G_;
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -126,6 +126,9 @@ __global__ void
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
@@ -161,29 +164,19 @@ template <ck::index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle;
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
@@ -222,17 +215,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -282,14 +277,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -372,19 +367,25 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
template <ck::index_t NDim,
|
||||
typename ck::enable_if<NDim == 2 &&
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>,
|
||||
bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -447,14 +448,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -539,19 +540,202 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename ck::enable_if<NDim == 2 &&
|
||||
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t HiStride = input_strides[3];
|
||||
const index_t WiStride = input_strides[4];
|
||||
const auto CStride = input_strides[2];
|
||||
|
||||
const index_t WoStride = output_strides[4];
|
||||
const auto KStride = Number<1>{};
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Ho * Wo, K), make_tuple(WoStride, KStride));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Hi * Wi, C), make_tuple(WiStride, CStride));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Ho * Wo, K), make_tuple(WoStride, KStride));
|
||||
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -621,14 +805,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -725,31 +909,70 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
|
||||
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1, 1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
|
||||
1,
|
||||
1,
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
1);
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
// type convert descs
|
||||
@@ -863,19 +1086,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t M01,
|
||||
const ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -913,6 +1138,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -927,18 +1154,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ =
|
||||
N * K *
|
||||
std::accumulate(begin(output_spatial_lengths),
|
||||
end(output_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ =
|
||||
N * C *
|
||||
std::accumulate(begin(input_spatial_lengths),
|
||||
end(input_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = output_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = input_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
K * C *
|
||||
std::accumulate(begin(filter_spatial_lengths),
|
||||
@@ -977,16 +1194,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
WeiElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_G_;
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads_;
|
||||
index_t k_batch_;
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -1091,6 +1308,45 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!(is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout,
|
||||
tensor_layout::convolution::
|
||||
GNHWK>)&&!(is_same_v<InLayout,
|
||||
tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout,
|
||||
tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout,
|
||||
tensor_layout::convolution::NHWGK>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -1134,21 +1390,23 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
const ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
@@ -1160,6 +1418,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1178,21 +1438,23 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k) override
|
||||
const ck::index_t split_k) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
@@ -1204,6 +1466,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1226,7 +1490,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle"
|
||||
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -91,6 +91,42 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv3d backward weight
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
@@ -162,66 +198,103 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
|
||||
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
|
||||
is_same_v<OutLayout, GNWK>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
if(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,81 +5,17 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
|
||||
using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_bf16_f32_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_bf16_f32_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -92,12 +28,22 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_bf16_f32_bf16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_bf16_f32_bf16_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
|
||||
1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -5,80 +5,17 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
|
||||
using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f16_default_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -91,11 +28,22 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f16_default_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f16_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
|
||||
1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -5,79 +5,17 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
|
||||
using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f32_default_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -90,11 +28,22 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f32_default_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f32_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
|
||||
1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -2,5 +2,8 @@ add_instance_library(device_grouped_conv2d_bwd_weight_instance
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
)
|
||||
|
||||
|
||||
@@ -1,85 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_bf16_f32_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -92,12 +23,22 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_bf16_f32_bf16_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
|
||||
2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,84 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f16_default_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -91,12 +23,22 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f16_default_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f16_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,83 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f32_default_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -90,12 +23,22 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f32_default_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f32_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
|
||||
2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -5,81 +5,17 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
|
||||
using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_bf16_f32_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -92,12 +28,22 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_bf16_f32_bf16_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
|
||||
3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -5,81 +5,17 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
|
||||
using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f16_default_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -92,12 +28,22 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f16_default_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f16_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -5,80 +5,17 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
|
||||
using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -91,12 +28,22 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances{});
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
|
||||
3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -141,3 +141,46 @@ avg_time: 0.768321
|
||||
tflops: 86.6679
|
||||
GB/s: 127.947
|
||||
```
|
||||
|
||||
## Profile grouped convolution backward weight kernels
|
||||
```bash
|
||||
# arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data)
|
||||
# arg2: data type (0: Input fp32, Weight fp32, Output fp32
|
||||
# 1: Input fp16, Weight fp16, Output fp16
|
||||
# 2: Input bf16, Weight fp32, Output bf16)
|
||||
# arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo]
|
||||
# 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
|
||||
# 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]
|
||||
# arg4: verification (0: no, 1: yes)
|
||||
# arg5: initialization (0: no init, 1: integer value, 2: decimal value)
|
||||
# arg6: print tensor value (0: no; 1: yes)
|
||||
# arg7: time kernel (0: no, 1: yes)
|
||||
# Following arguments (depending on number of spatial dims):
|
||||
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)
|
||||
# G, N, K, C,
|
||||
# <filter spatial dimensions>, (ie Y, X for 2D)
|
||||
# <input image spatial dimensions>, (ie Hi, Wi for 2D)
|
||||
# <strides>, (ie Sy, Sx for 2D)
|
||||
# <dilations>, (ie Dy, Dx for 2D)
|
||||
# <left padding>, (ie LeftPy, LeftPx for 2D)
|
||||
# <right padding>, (ie RightPy, RightPx for 2D)
|
||||
# SplitK
|
||||
|
||||
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
|
||||
./bin/ckProfiler grouped_conv_bwd_data 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
|
||||
|
||||
```
|
||||
|
||||
Result (MI100, FP16, GNHWC_GKYXC_GNHWK)
|
||||
```
|
||||
input: dim 5, lengths {32, 512, 1024, 28, 28}, strides {411041792, 802816, 1, 28672, 1024}
|
||||
weight: dim 5, lengths {32, 512, 1024, 3, 3}, strides {4718592, 9216, 1, 3072, 1024}
|
||||
output: dim 5, lengths {32, 512, 512, 26, 26}, strides {177209344, 346112, 1, 13312, 512}
|
||||
....
|
||||
Best configuration parameters:
|
||||
name: DeviceGroupedConvBwdWeight_Xdl_CShuffle<256, 256, 128, 4, Default, 8, 4, 2, 8, 4, 8, 2, 1, 1, 8>
|
||||
avg_time: 68.5216
|
||||
tflops: 95.337
|
||||
GB/s: 69.2301
|
||||
```
|
||||
Note: This kernel use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time.
|
||||
|
||||
@@ -139,6 +139,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
@@ -149,6 +151,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
|
||||
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
|
||||
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
|
||||
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
@@ -167,6 +171,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -15,6 +15,7 @@ enum struct ConvLayout
|
||||
{
|
||||
GNCHW_GKCYX_GNKHW, // 0
|
||||
GNHWC_GKYXC_GNHWK, // 1
|
||||
NHWGC_GKYXC_NHWGK, // 2
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -37,6 +38,8 @@ static void print_helper_msg()
|
||||
"N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
|
||||
"N, Ho, Wo, K]\n"
|
||||
<< " 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, "
|
||||
"Ho, Wo, G, K]\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -82,6 +85,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
@@ -90,6 +94,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
@@ -157,6 +162,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
|
||||
@@ -2,8 +2,10 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -1,91 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
ck::index_t split_k{2};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
bool pass;
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes =
|
||||
::testing::Types<std::tuple<float>, std::tuple<ck::half_t>, std::tuple<ck::bhalf_t>>;
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight, Test1D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
this->template Run<1>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using InDataType = std::tuple_element_t<0, Tuple>;
|
||||
using WeiDataType = std::tuple_element_t<1, Tuple>;
|
||||
using OutDataType = std::tuple_element_t<2, Tuple>;
|
||||
using InLayout = std::tuple_element_t<3, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<4, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<5, Tuple>;
|
||||
using NDimSpatial = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
ck::index_t split_k{2};
|
||||
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight1d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight2d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes1d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>>;
|
||||
using KernelTypes2d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
|
||||
std::tuple<float, float, float, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>>;
|
||||
using KernelTypes3d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->Run();
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using ConvolutionBackwardWeightSpecialization =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default;
|
||||
static constexpr auto Filter1x1Stride1Pad0 =
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <typename Tuple, ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
using InLayout = std::tuple_element_t<2, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<1, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<0, Tuple>;
|
||||
|
||||
// clang-format off
|
||||
using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
//##########| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//##########| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//##########| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
< NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
ck::index_t split_k{2};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
|
||||
|
||||
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
|
||||
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
|
||||
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
|
||||
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
auto conv = GroupedConvBwdWeightDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
conv_param.G_,
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
split_k);
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
|
||||
using KernelTypes =
|
||||
::testing::Types<std::tuple<GNHWK, GKYXC, GNHWC>, std::tuple<NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeightDefault
|
||||
: public TestGroupedConvndBwdWeight<Tuple, ConvBwdWeightDefault>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeightFilter1x1
|
||||
: public TestGroupedConvndBwdWeight<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x1, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x1, SpecializationCheck)
|
||||
{
|
||||
// Check filter 3,3 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check strides 2,2 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check with pad
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Supported version
|
||||
this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck)
|
||||
{
|
||||
// vector load for A
|
||||
this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
// vector load for B, E, Ds
|
||||
this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
Reference in New Issue
Block a user