mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64 bit indexing * Add new grouped conv fwd kernel for large tensors * Add instances large tensor * Fixes for transform conv to gemm * Fixes * fixes * Remove not needed instances * examples fixes * Remove not need ds arrays * Fix tests * Add 2GB check in gridwise dl * Fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
|
||||
inline HostTensorDescriptor
|
||||
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
|
||||
{
|
||||
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
|
||||
std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
|
||||
|
||||
ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
|
||||
|
||||
for(ck::index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
input_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
|
||||
filter_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
|
||||
output_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
|
||||
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
|
||||
conv_filter_dilations_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
|
||||
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
|
||||
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
|
||||
}
|
||||
|
||||
// do GEMM
|
||||
auto conv = DeviceConvNdBwdDataInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
|
||||
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_,
|
||||
conv_param.filter_spatial_lengths_,
|
||||
conv_param.GetOutputSpatialLengths(),
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
static_cast<ck::index_t>(conv_param.N_),
|
||||
static_cast<ck::index_t>(conv_param.K_),
|
||||
static_cast<ck::index_t>(conv_param.C_),
|
||||
input_spatial_lengths_i32,
|
||||
filter_spatial_lengths_i32,
|
||||
output_spatial_lengths_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
// Pure-virtual MakeArgumentPointer overload whose tensor lengths/strides and convolution
// parameters (filter strides, dilations, left/right pads) are arrays of long_index_t.
// Takes the A/B pointers, the D pointers, the E pointer and the three element-wise ops,
// and returns the argument object as a std::unique_ptr<BaseArgument>.
// Declared `= 0`, so a concrete device operation has to provide the implementation.
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
|
||||
@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
template <typename BLay>
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
template <typename ELay>
|
||||
__host__ __device__ static auto
|
||||
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// MakeArgument overload that accepts ck::Array<long_index_t, ...> lengths, strides and
// convolution parameters. Every input array is copied into a local std::array<index_t, ...>
// through array_convert, and the Argument is then built from those index_t copies together
// with the unchanged pointers and element-wise operations.
// NOTE(review): array_convert is defined outside this view; presumably it converts element by
// element from long_index_t to index_t, in which case values outside the index_t range would
// not be preserved -- confirm that callers stay within range.
static __device__ __host__ auto MakeArgument(
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const ck::Array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
// index_t destinations, one per long_index_t input array.
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// The Ds descriptors are arrays of arrays, so each D tensor is converted separately.
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// Build the Argument from the converted index_t arrays.
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
|
||||
|
||||
static constexpr auto spatial_offset = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer =
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
|
||||
@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
|
||||
: independent_filter_stride;
|
||||
}
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
|
||||
@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// MakeArgument overload that accepts std::array<long_index_t, ...> lengths, strides and
// convolution parameters. Every input array is copied into a local std::array<index_t, ...>
// through array_convert, and the Argument is then built from those index_t copies together
// with the unchanged pointers and element-wise operations.
// NOTE(review): array_convert is defined outside this view; presumably it converts element by
// element from long_index_t to index_t, in which case values outside the index_t range would
// not be preserved -- confirm that callers stay within range.
static auto
|
||||
MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
// index_t destinations, one per long_index_t input array.
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// The Ds descriptors are arrays of arrays, so each D tensor is converted separately.
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// Build the Argument from the converted index_t arrays.
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// Override of the MakeArgumentPointer overload that takes std::array<long_index_t, ...>
// lengths, strides and convolution parameters. Every input array is copied into a local
// std::array<index_t, ...> through array_convert, and a heap-allocated Argument is then
// created from those index_t copies (plus the unchanged pointers and element-wise ops) and
// returned as std::unique_ptr<BaseArgument>.
// NOTE(review): array_convert is defined outside this view; presumably it converts element by
// element from long_index_t to index_t, in which case values outside the index_t range would
// not be preserved -- confirm that callers stay within range.
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
// index_t destinations, one per long_index_t input array.
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// The Ds descriptors are arrays of arrays, so each D tensor is converted separately.
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// Allocate the Argument from the converted index_t arrays.
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
template <typename CLay>
|
||||
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
|
||||
@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
|
||||
@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ALayout,
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths,
|
||||
@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using EGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
|
||||
|
||||
@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise gemm v3 doesn't verify descriptors size
|
||||
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_as,
|
||||
const void* p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
|
||||
}
|
||||
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
EDataType* p_e_grid_;
|
||||
typename GridwiseGemm::RsGridPointer p_rs_grid_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
// tensor descriptors for problem definition
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
|
||||
@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
static constexpr auto BEnableLds =
|
||||
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc =
|
||||
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
|
||||
using BGridDesc =
|
||||
@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using GemmToConvFwdTransformer =
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
|
||||
b_g_k_c_xs_lengths[I2] = C;
|
||||
c_g_n_k_wos_lengths[I1] = N;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
|
||||
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
|
||||
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
|
||||
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
|
||||
|
||||
@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1>
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
private:
|
||||
@@ -46,10 +47,10 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
{
|
||||
const long_index_t a_element_space_size =
|
||||
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
|
||||
@@ -59,7 +60,7 @@ struct TransformConvFwdToGemm
|
||||
c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
const index_t N = a_g_n_c_wis_lengths[I1];
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if(element_space_size > TwoGB)
|
||||
{
|
||||
@@ -70,7 +71,7 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
// Find least divisor of N larger than element_space_size / TwoGB
|
||||
// Iterate up to sqrt(N). There are no divisors above this value.
|
||||
for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
|
||||
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
|
||||
least_divisor++)
|
||||
{
|
||||
if(N % least_divisor == 0)
|
||||
@@ -98,6 +99,53 @@ struct TransformConvFwdToGemm
|
||||
public:
|
||||
__host__ __device__ constexpr TransformConvFwdToGemm() {}
|
||||
|
||||
// Converting constructor: builds this transformer from another object that exposes
// the same member names, casting every copied member to this instantiation's
// IndexType with static_cast.
// NOTE(review): presumably used to convert between instantiations that differ in
// IndexType (e.g. 64-bit -> 32-bit); the casts are unchecked, so values outside
// IndexType's range are not detected here -- confirm callers guarantee the range.
template <typename TransformConvFwdToGemmBase>
|
||||
__host__ __device__
|
||||
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
|
||||
// Tensor lengths: batch, input/output spatial sizes, filter sizes, channels.
: N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
|
||||
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
|
||||
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
|
||||
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
|
||||
Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
|
||||
Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
|
||||
Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
|
||||
Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
|
||||
Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
|
||||
X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
|
||||
K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
|
||||
C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
|
||||
// Tensor strides.
DiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DiStride_)},
|
||||
HiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HiStride_)},
|
||||
WiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WiStride_)},
|
||||
DoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DoStride_)},
|
||||
HoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HoStride_)},
|
||||
WoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WoStride_)},
|
||||
XStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.XStride_)},
|
||||
CStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorA_)},
|
||||
CStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorB_)},
|
||||
KStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorB_)},
|
||||
KStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorC_)},
|
||||
NStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorA_)},
|
||||
NStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorC_)},
|
||||
GStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorA_)},
|
||||
GStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorB_)},
|
||||
GStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorC_)},
|
||||
// Convolution strides, dilations and paddings.
ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
|
||||
ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
|
||||
ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
|
||||
ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
|
||||
ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
|
||||
ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
|
||||
InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
|
||||
InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
|
||||
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
|
||||
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
|
||||
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
|
||||
// ZYX_ is copied from the source object rather than recomputed from Z_, Y_, X_.
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
typename ConvSpatialDimsType,
|
||||
index_t NDim = NDimSpatial,
|
||||
@@ -126,6 +174,8 @@ struct TransformConvFwdToGemm
|
||||
DiStride_{I1},
|
||||
HiStride_{I1},
|
||||
WiStride_{a_g_n_c_wis_strides[I3]},
|
||||
DoStride_{I1},
|
||||
HoStride_{I1},
|
||||
WoStride_{c_g_n_k_wos_strides[I3]},
|
||||
XStride_{b_g_k_c_xs_strides[I3]},
|
||||
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
|
||||
@@ -133,6 +183,7 @@ struct TransformConvFwdToGemm
|
||||
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
|
||||
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
|
||||
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
|
||||
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
|
||||
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
|
||||
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
|
||||
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
|
||||
@@ -150,10 +201,10 @@ struct TransformConvFwdToGemm
|
||||
InRightPadW_{input_right_pads[I0]},
|
||||
ZYX_{X_}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
@@ -164,7 +215,6 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
NDoHoWo_ = N_ * Wo_;
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -195,6 +245,8 @@ struct TransformConvFwdToGemm
|
||||
DiStride_{I1},
|
||||
HiStride_{a_g_n_c_wis_strides[I3]},
|
||||
WiStride_{a_g_n_c_wis_strides[I4]},
|
||||
DoStride_{I1},
|
||||
HoStride_{c_g_n_k_wos_strides[I3]},
|
||||
WoStride_{c_g_n_k_wos_strides[I4]},
|
||||
XStride_{b_g_k_c_xs_strides[I4]},
|
||||
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
|
||||
@@ -202,6 +254,7 @@ struct TransformConvFwdToGemm
|
||||
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
|
||||
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
|
||||
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
|
||||
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
|
||||
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
|
||||
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
|
||||
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
|
||||
@@ -219,10 +272,10 @@ struct TransformConvFwdToGemm
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
@@ -233,7 +286,6 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
NDoHoWo_ = N_ * Ho_ * Wo_;
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -264,6 +316,8 @@ struct TransformConvFwdToGemm
|
||||
DiStride_{a_g_n_c_wis_strides[I3]},
|
||||
HiStride_{a_g_n_c_wis_strides[I4]},
|
||||
WiStride_{a_g_n_c_wis_strides[I5]},
|
||||
DoStride_{c_g_n_k_wos_strides[I3]},
|
||||
HoStride_{c_g_n_k_wos_strides[I4]},
|
||||
WoStride_{c_g_n_k_wos_strides[I5]},
|
||||
XStride_{b_g_k_c_xs_strides[I5]},
|
||||
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
|
||||
@@ -271,6 +325,7 @@ struct TransformConvFwdToGemm
|
||||
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
|
||||
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
|
||||
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
|
||||
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
|
||||
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
|
||||
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
|
||||
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
|
||||
@@ -288,10 +343,10 @@ struct TransformConvFwdToGemm
|
||||
InRightPadW_{input_right_pads[I2]},
|
||||
ZYX_{Z_ * Y_ * X_}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
@@ -302,7 +357,122 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
|
||||
}
|
||||
|
||||
// Reports whether the input (A) and output (C) tensors described by this
// transformer each occupy at most 2 GB (2^31 bytes).
//
// The element space size of a tensor is 1 + sum((length - 1) * stride) over its
// N, D, H, W and C (resp. K) dimensions. Every term is widened to long_index_t
// before the multiplication: with a 32-bit IndexType the product can overflow for
// exactly the oversized tensors this check has to reject, which would make the
// result unreliable.
//
// Returns true only when both tensors fit.
__host__ bool AreDescriptorsSmallerThan2GB() const
{
    constexpr long_index_t TwoGB = (long_index_t{1} << 31);

    // Extent contributed by a single dimension, computed in 64-bit arithmetic.
    const auto dim_span = [](const IndexType length, const IndexType stride) {
        return (static_cast<long_index_t>(length) - 1) * static_cast<long_index_t>(stride);
    };

    const long_index_t in_desc_space_size =
        long_index_t{1} + dim_span(N_, NStrideTensorA_) + dim_span(Di_, DiStride_) +
        dim_span(Hi_, HiStride_) + dim_span(Wi_, WiStride_) + dim_span(C_, CStrideTensorA_);
    const long_index_t out_desc_space_size =
        long_index_t{1} + dim_span(N_, NStrideTensorC_) + dim_span(Do_, DoStride_) +
        dim_span(Ho_, HoStride_) + dim_span(Wo_, WoStride_) + dim_span(K_, KStrideTensorC_);

    bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
    bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;

    return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}
|
||||
|
||||
// Splits this convolution into a "left" and a "right" sub-problem along a single
// output spatial dimension, trying D first, then H, then W.
//
// A dimension is chosen only if its output length is not 1, the right half starts
// past the left padding, and the left half ends within the left padding plus the
// input length. That keeps all left padding in the left half and all right padding
// in the right half. If no dimension qualifies, both returned transformers are
// unmodified copies of *this and both returned pointers equal the base pointers.
//
// Returns a tuple of: left transformer, right transformer, input (A) pointer of
// the right half, output (C) pointer of the right half.
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
                               CDataType* c_grid_ptr_base) const
{
    auto left  = *this;
    auto right = *this;

    // Element offsets of the right half relative to the base pointers.
    IndexType a_offset_right = 0;
    IndexType c_offset_right = 0;

    // Filter extents with dilation applied.
    const IndexType eff_z = (Z_ - 1) * ConvDilationD_ + 1;
    const IndexType eff_y = (Y_ - 1) * ConvDilationH_ + 1;
    const IndexType eff_x = (X_ - 1) * ConvDilationW_ + 1;

    // Output lengths that would go to the left half.
    const IndexType half_do = Do_ / 2;
    const IndexType half_ho = Ho_ / 2;
    const IndexType half_wo = Wo_ / 2;

    // Start position of the right half in the input.
    const IndexType right_begin_d = half_do * ConvStrideD_;
    const IndexType right_begin_h = half_ho * ConvStrideH_;
    const IndexType right_begin_w = half_wo * ConvStrideW_;

    // Last position of the left half in the input.
    const IndexType left_end_d = (half_do - 1) * ConvStrideD_ + eff_z;
    const IndexType left_end_h = (half_ho - 1) * ConvStrideH_ + eff_y;
    const IndexType left_end_w = (half_wo - 1) * ConvStrideW_ + eff_x;

    // Splitting is allowed only when the whole left padding lands in the left
    // tensor and the whole right padding in the right tensor.
    const bool can_split_d =
        Do_ != 1 && right_begin_d > InLeftPadD_ && left_end_d <= (InLeftPadD_ + Di_);
    const bool can_split_h =
        Ho_ != 1 && right_begin_h > InLeftPadH_ && left_end_h <= (InLeftPadH_ + Hi_);
    const bool can_split_w =
        Wo_ != 1 && right_begin_w > InLeftPadW_ && left_end_w <= (InLeftPadW_ + Wi_);

    if(can_split_d)
    {
        // Halve the output; the right half receives the remainder.
        left.Do_  = half_do;
        right.Do_ = Do_ - half_do;
        // Left padding belongs to the left half, right padding to the right half.
        left.InLeftPadD_   = InLeftPadD_;
        right.InLeftPadD_  = 0;
        left.InRightPadD_  = 0;
        right.InRightPadD_ = InRightPadD_;
        // Input lengths covered by each half.
        left.Di_  = left_end_d - InLeftPadD_;
        right.Di_ = math::min(Di_ - (right_begin_d - InLeftPadD_),
                              (right.Do_ - 1) * ConvStrideD_ + eff_z);
        // Calculate offsets of the right half.
        a_offset_right = (right_begin_d - InLeftPadD_) * DiStride_;
        c_offset_right = half_do * DoStride_;
    }
    else if(can_split_h)
    {
        left.Ho_  = half_ho;
        right.Ho_ = Ho_ - half_ho;

        left.InLeftPadH_   = InLeftPadH_;
        right.InLeftPadH_  = 0;
        left.InRightPadH_  = 0;
        right.InRightPadH_ = InRightPadH_;

        left.Hi_  = left_end_h - InLeftPadH_;
        right.Hi_ = math::min(Hi_ - (right_begin_h - InLeftPadH_),
                              (right.Ho_ - 1) * ConvStrideH_ + eff_y);

        a_offset_right = (right_begin_h - InLeftPadH_) * HiStride_;
        c_offset_right = half_ho * HoStride_;
    }
    else if(can_split_w)
    {
        left.Wo_  = half_wo;
        right.Wo_ = Wo_ - half_wo;

        left.InLeftPadW_   = InLeftPadW_;
        right.InLeftPadW_  = 0;
        left.InRightPadW_  = 0;
        right.InRightPadW_ = InRightPadW_;

        left.Wi_  = left_end_w - InLeftPadW_;
        right.Wi_ = math::min(Wi_ - (right_begin_w - InLeftPadW_),
                              (right.Wo_ - 1) * ConvStrideW_ + eff_x);

        a_offset_right = (right_begin_w - InLeftPadW_) * WiStride_;
        c_offset_right = half_wo * WoStride_;
    }

    // Left transformer, right transformer, right-half input pointer and
    // right-half output pointer.
    return ck::make_tuple(
        left, right, a_grid_ptr_base + a_offset_right, c_grid_ptr_base + c_offset_right);
}
|
||||
|
||||
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimension as
|
||||
@@ -320,20 +490,27 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
|
||||
make_tuple(WiStride_, CStrideTensorA_));
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -527,20 +704,29 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
|
||||
make_tuple(WiStride_, CStrideTensorA_));
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -759,20 +945,34 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
|
||||
make_tuple(WiStride_, CStrideTensorA_));
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_,
|
||||
DiStride_,
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -1119,45 +1319,70 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename CLayout,
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<NDimSp == 2 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<NDimSp == 3 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
typename std::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNWK>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
// Pad the size-1 dimension up to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo_),
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only matrices from diagonal. X_or returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
@@ -1167,7 +1392,7 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo_),
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
@@ -1175,45 +1400,146 @@ struct TransformConvFwdToGemm
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// for output bias
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
|
||||
bool>::type = false>
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<
|
||||
NDimSp == 2 && (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(NStrideTensorC_,
|
||||
HoStride_,
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only matrices from diagonal. X_or returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
index_t N_;
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
typename std::enable_if<
|
||||
NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
|
||||
private:
|
||||
const index_t Di_, Hi_, Wi_;
|
||||
const index_t Do_, Ho_, Wo_;
|
||||
const index_t Z_, Y_, X_;
|
||||
const index_t K_, C_;
|
||||
const index_t DiStride_, HiStride_, WiStride_;
|
||||
const index_t WoStride_;
|
||||
const index_t XStride_;
|
||||
const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
|
||||
const index_t NStrideTensorA_;
|
||||
const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
|
||||
const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
|
||||
const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
const index_t InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
const index_t ZYX_;
|
||||
index_t NDoHoWo_;
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(NStrideTensorC_,
|
||||
DoStride_,
|
||||
HoStride_,
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only matrices from diagonal. X_or returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
IndexType N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
IndexType Z_, Y_, X_;
|
||||
IndexType K_, C_;
|
||||
IndexType DiStride_, HiStride_, WiStride_;
|
||||
IndexType DoStride_, HoStride_, WoStride_;
|
||||
IndexType XStride_;
|
||||
IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
|
||||
IndexType NStrideTensorA_, NStrideTensorC_;
|
||||
IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
|
||||
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
|
||||
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
IndexType ZYX_;
|
||||
};
|
||||
|
||||
// wrapper class to call member functions on TransformConvToGemm struct at runtime
|
||||
@@ -1230,17 +1556,17 @@ struct TransformConv
|
||||
if(NDimSpatial == 2)
|
||||
{
|
||||
return conv_fwd_to_gemm
|
||||
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
|
||||
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK, 2>();
|
||||
}
|
||||
else if(NDimSpatial == 3)
|
||||
{
|
||||
return conv_fwd_to_gemm
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK, 3>();
|
||||
}
|
||||
else if(NDimSpatial == 1)
|
||||
{
|
||||
return conv_fwd_to_gemm
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK, 1>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/f8_utils.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/array.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for gfx94x models
|
||||
@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Y, typename X, std::size_t NumElems>
|
||||
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
|
||||
const std::array<X, NumElems>& x)
|
||||
{
|
||||
for(std::size_t i = 0; i < NumElems; i++)
|
||||
{
|
||||
y[i] = type_convert<Y>(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Y, typename X, index_t NumElems>
|
||||
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
|
||||
{
|
||||
for(std::size_t i = 0; i < NumElems; i++)
|
||||
{
|
||||
y[i] = type_convert<Y>(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
public:
|
||||
Argument(const Tensor<InDataType>& input,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads)
|
||||
: input_{input},
|
||||
output_{output},
|
||||
conv_strides_{conv_filter_strides},
|
||||
@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
const Tensor<InDataType>& input_;
|
||||
Tensor<OutDataType>& output_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
std::vector<index_t> in_left_pads_;
|
||||
std::vector<index_t> in_right_pads_;
|
||||
std::vector<long_index_t> conv_strides_;
|
||||
std::vector<long_index_t> conv_dilations_;
|
||||
std::vector<long_index_t> in_left_pads_;
|
||||
std::vector<long_index_t> in_right_pads_;
|
||||
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<long_index_t> filter_spatial_lengths_;
|
||||
std::vector<long_index_t> output_spatial_lengths_;
|
||||
|
||||
private:
|
||||
void initOutputSpatialLengths()
|
||||
{
|
||||
constexpr auto input_offset_to_spatial = 3;
|
||||
|
||||
for(ck::index_t i = 0; i < NDimSpatial; ++i)
|
||||
for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
|
||||
const ck::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_.push_back(
|
||||
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
|
||||
@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
const index_t G = arg.output_.GetLengths()[0];
|
||||
const index_t N = arg.output_.GetLengths()[1];
|
||||
const index_t C = arg.output_.GetLengths()[2];
|
||||
const long_index_t G = arg.output_.GetLengths()[0];
|
||||
const long_index_t N = arg.output_.GetLengths()[1];
|
||||
const long_index_t C = arg.output_.GetLengths()[2];
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const index_t Wo = arg.output_spatial_lengths_[0];
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
const long_index_t Wo = arg.output_spatial_lengths_[0];
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(long_index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
index_t row = n * Wo + wo;
|
||||
index_t column = 0;
|
||||
long_index_t row = n * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
|
||||
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
|
||||
{
|
||||
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
|
||||
@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const index_t Ho = arg.output_spatial_lengths_[0];
|
||||
const index_t Wo = arg.output_spatial_lengths_[1];
|
||||
const long_index_t Ho = arg.output_spatial_lengths_[0];
|
||||
const long_index_t Wo = arg.output_spatial_lengths_[1];
|
||||
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
for(long_index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
for(long_index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
index_t column = 0;
|
||||
long_index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
|
||||
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
|
||||
{
|
||||
auto hi =
|
||||
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
|
||||
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
|
||||
if(hi >= 0 &&
|
||||
@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const index_t Do = arg.output_spatial_lengths_[0];
|
||||
const index_t Ho = arg.output_spatial_lengths_[1];
|
||||
const index_t Wo = arg.output_spatial_lengths_[2];
|
||||
const long_index_t Do = arg.output_spatial_lengths_[0];
|
||||
const long_index_t Ho = arg.output_spatial_lengths_[1];
|
||||
const long_index_t Wo = arg.output_spatial_lengths_[2];
|
||||
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(index_t d_o = 0; d_o < Do; ++d_o)
|
||||
for(long_index_t d_o = 0; d_o < Do; ++d_o)
|
||||
{
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
for(long_index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
for(long_index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
index_t column = 0;
|
||||
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
|
||||
for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
auto di =
|
||||
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
|
||||
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
auto hi =
|
||||
static_cast<ck::long_index_t>(ho *
|
||||
@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
static_cast<ck::long_index_t>(y *
|
||||
arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
|
||||
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
|
||||
++x)
|
||||
{
|
||||
auto wi =
|
||||
static_cast<ck::long_index_t>(
|
||||
@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
static_cast<ck::long_index_t>(
|
||||
x * arg.conv_dilations_[2]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(di >= 0 &&
|
||||
ck::type_convert<std::size_t>(di) <
|
||||
@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
const ck::index_t G = arg.output_.GetLengths()[0];
|
||||
const ck::index_t N = arg.output_.GetLengths()[1];
|
||||
const ck::index_t C = arg.output_.GetLengths()[2];
|
||||
const ck::long_index_t G = arg.output_.GetLengths()[0];
|
||||
const ck::long_index_t N = arg.output_.GetLengths()[1];
|
||||
const ck::long_index_t C = arg.output_.GetLengths()[2];
|
||||
|
||||
const index_t NDoHoWo =
|
||||
N * ck::accumulate_n<index_t>(
|
||||
const long_index_t NDoHoWo =
|
||||
N * ck::accumulate_n<long_index_t>(
|
||||
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t CZYX =
|
||||
C * ck::accumulate_n<index_t>(
|
||||
const long_index_t CZYX =
|
||||
C * ck::accumulate_n<long_index_t>(
|
||||
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
|
||||
@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
|
||||
static auto MakeArgument(const Tensor<InDataType>& input,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads)
|
||||
{
|
||||
return Argument{input,
|
||||
output,
|
||||
|
||||
@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
const Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
|
||||
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
std::vector<index_t> in_left_pads_;
|
||||
std::vector<index_t> in_right_pads_;
|
||||
std::vector<long_index_t> conv_strides_;
|
||||
std::vector<long_index_t> conv_dilations_;
|
||||
std::vector<long_index_t> in_left_pads_;
|
||||
std::vector<long_index_t> in_right_pads_;
|
||||
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
const Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
|
||||
@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
const Tensor<InDataType>& in_n_c_hi_wi,
|
||||
Tensor<WeiDataType>& wei_k_c_y_x,
|
||||
const Tensor<OutDataType>& out_n_k_ho_wo,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
|
||||
const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
std::vector<index_t> in_left_pads_;
|
||||
std::vector<index_t> in_right_pads_;
|
||||
std::vector<long_index_t> conv_strides_;
|
||||
std::vector<long_index_t> conv_dilations_;
|
||||
std::vector<long_index_t> in_left_pads_;
|
||||
std::vector<long_index_t> in_right_pads_;
|
||||
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
const Tensor<InDataType>& in_n_c_hi_wi,
|
||||
Tensor<WeiDataType>& wei_k_c_y_x,
|
||||
const Tensor<OutDataType>& out_n_k_ho_wo,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
|
||||
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
std::vector<index_t> in_left_pads_;
|
||||
std::vector<index_t> in_right_pads_;
|
||||
std::vector<ck::long_index_t> conv_strides_;
|
||||
std::vector<ck::long_index_t> conv_dilations_;
|
||||
std::vector<ck::long_index_t> in_left_pads_;
|
||||
std::vector<ck::long_index_t> in_right_pads_;
|
||||
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
|
||||
@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
public:
|
||||
Argument(const Tensor<InDataType>& input,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads)
|
||||
: input_{input},
|
||||
output_{output},
|
||||
conv_strides_{conv_filter_strides},
|
||||
@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
const Tensor<InDataType>& input_;
|
||||
Tensor<OutDataType>& output_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
std::vector<index_t> in_left_pads_;
|
||||
std::vector<index_t> in_right_pads_;
|
||||
std::vector<long_index_t> conv_strides_;
|
||||
std::vector<long_index_t> conv_dilations_;
|
||||
std::vector<long_index_t> in_left_pads_;
|
||||
std::vector<long_index_t> in_right_pads_;
|
||||
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<long_index_t> filter_spatial_lengths_;
|
||||
std::vector<long_index_t> output_spatial_lengths_;
|
||||
|
||||
private:
|
||||
void initOutputSpatialLengths()
|
||||
@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
|
||||
const ck::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_.push_back(
|
||||
(input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
|
||||
@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
const index_t G = arg.input_.GetLengths()[0];
|
||||
const index_t N = arg.input_.GetLengths()[1];
|
||||
const index_t C = arg.input_.GetLengths()[2];
|
||||
const long_index_t G = arg.input_.GetLengths()[0];
|
||||
const long_index_t N = arg.input_.GetLengths()[1];
|
||||
const long_index_t C = arg.input_.GetLengths()[2];
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const index_t Wo = arg.output_spatial_lengths_[0];
|
||||
auto func = [&](auto g, auto n, auto wo) {
|
||||
index_t row = n * Wo + wo;
|
||||
index_t column = 0;
|
||||
const long_index_t Wo = arg.output_spatial_lengths_[0];
|
||||
auto func = [&](auto g, auto n, auto wo) {
|
||||
long_index_t row = n * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
|
||||
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
|
||||
{
|
||||
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
|
||||
@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const index_t Ho = arg.output_spatial_lengths_[0];
|
||||
const index_t Wo = arg.output_spatial_lengths_[1];
|
||||
const long_index_t Ho = arg.output_spatial_lengths_[0];
|
||||
const long_index_t Wo = arg.output_spatial_lengths_[1];
|
||||
|
||||
auto func = [&](auto g, auto n, auto ho, auto wo) {
|
||||
index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
index_t column = 0;
|
||||
long_index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
|
||||
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
|
||||
{
|
||||
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
|
||||
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
|
||||
{
|
||||
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
|
||||
if(hi >= 0 &&
|
||||
@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const index_t Do = arg.output_spatial_lengths_[0];
|
||||
const index_t Ho = arg.output_spatial_lengths_[1];
|
||||
const index_t Wo = arg.output_spatial_lengths_[2];
|
||||
const long_index_t Do = arg.output_spatial_lengths_[0];
|
||||
const long_index_t Ho = arg.output_spatial_lengths_[1];
|
||||
const long_index_t Wo = arg.output_spatial_lengths_[2];
|
||||
|
||||
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
|
||||
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
index_t column = 0;
|
||||
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
|
||||
for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
|
||||
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
|
||||
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(di >= 0 &&
|
||||
ck::type_convert<std::size_t>(di) <
|
||||
@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
const ck::index_t G = arg.input_.GetLengths()[0];
|
||||
const ck::index_t N = arg.input_.GetLengths()[1];
|
||||
const ck::index_t C = arg.input_.GetLengths()[2];
|
||||
const ck::long_index_t G = arg.input_.GetLengths()[0];
|
||||
const ck::long_index_t N = arg.input_.GetLengths()[1];
|
||||
const ck::long_index_t C = arg.input_.GetLengths()[2];
|
||||
|
||||
const index_t NDoHoWo =
|
||||
N * ck::accumulate_n<index_t>(
|
||||
const long_index_t NDoHoWo =
|
||||
N * ck::accumulate_n<long_index_t>(
|
||||
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t CZYX =
|
||||
C * ck::accumulate_n<index_t>(
|
||||
const long_index_t CZYX =
|
||||
C * ck::accumulate_n<long_index_t>(
|
||||
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
|
||||
@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
|
||||
static auto MakeArgument(const Tensor<InDataType>& input,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths,
|
||||
std::vector<ck::long_index_t> conv_filter_strides,
|
||||
std::vector<ck::long_index_t> conv_filter_dilations,
|
||||
std::vector<ck::long_index_t> input_left_pads,
|
||||
std::vector<ck::long_index_t> input_right_pads)
|
||||
{
|
||||
return Argument{input,
|
||||
output,
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
|
||||
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -17,6 +17,7 @@
|
||||
#endif
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_forward_xdl.inc"
|
||||
#include "grouped_convolution_forward_xdl_large_tensor.inc"
|
||||
#include "grouped_convolution_forward_xdl_merged_groups.inc"
|
||||
#include "grouped_convolution_forward_comp_xdl.inc"
|
||||
#include "grouped_convolution_forward_mem_inter_xdl.inc"
|
||||
@@ -200,6 +201,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
|
||||
@@ -215,6 +218,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
|
||||
@@ -232,6 +237,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
|
||||
@@ -291,6 +298,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
|
||||
@@ -347,6 +356,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
|
||||
@@ -364,6 +375,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -31,23 +31,35 @@ struct ConvParam
|
||||
const std::vector<ck::index_t>& left_pads,
|
||||
const std::vector<ck::index_t>& right_pads);
|
||||
|
||||
ck::index_t num_dim_spatial_;
|
||||
ck::index_t G_;
|
||||
ck::index_t N_;
|
||||
ck::index_t K_;
|
||||
ck::index_t C_;
|
||||
ConvParam(ck::long_index_t n_dim,
|
||||
ck::long_index_t group_count,
|
||||
ck::long_index_t n_batch,
|
||||
ck::long_index_t n_out_channels,
|
||||
ck::long_index_t n_in_channels,
|
||||
const std::vector<ck::long_index_t>& filters_len,
|
||||
const std::vector<ck::long_index_t>& input_len,
|
||||
const std::vector<ck::long_index_t>& strides,
|
||||
const std::vector<ck::long_index_t>& dilations,
|
||||
const std::vector<ck::long_index_t>& left_pads,
|
||||
const std::vector<ck::long_index_t>& right_pads);
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths_;
|
||||
std::vector<ck::index_t> input_spatial_lengths_;
|
||||
std::vector<ck::index_t> output_spatial_lengths_;
|
||||
ck::long_index_t num_dim_spatial_;
|
||||
ck::long_index_t G_;
|
||||
ck::long_index_t N_;
|
||||
ck::long_index_t K_;
|
||||
ck::long_index_t C_;
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_;
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths_;
|
||||
std::vector<ck::long_index_t> input_spatial_lengths_;
|
||||
std::vector<ck::long_index_t> output_spatial_lengths_;
|
||||
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
std::vector<ck::long_index_t> conv_filter_strides_;
|
||||
std::vector<ck::long_index_t> conv_filter_dilations_;
|
||||
|
||||
std::vector<ck::index_t> GetOutputSpatialLengths() const;
|
||||
std::vector<ck::long_index_t> input_left_pads_;
|
||||
std::vector<ck::long_index_t> input_right_pads_;
|
||||
|
||||
std::vector<ck::long_index_t> GetOutputSpatialLengths() const;
|
||||
|
||||
std::size_t GetFlops() const;
|
||||
|
||||
|
||||
@@ -96,9 +96,16 @@ struct HostTensorDescriptor
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens)
|
||||
: mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename Lengths,
|
||||
typename = std::enable_if_t<
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> ||
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t>>>
|
||||
HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
@@ -114,11 +121,19 @@ struct HostTensorDescriptor
|
||||
{
|
||||
}
|
||||
|
||||
HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens,
|
||||
const std::initializer_list<ck::long_index_t>& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Lengths,
|
||||
typename Strides,
|
||||
typename = std::enable_if_t<
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
|
||||
(std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>) ||
|
||||
(std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t> &&
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Strides>, ck::long_index_t>)>>
|
||||
HostTensorDescriptor(const Lengths& lens, const Strides& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
{
|
||||
|
||||
@@ -9,6 +9,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
# large tensor
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
# merged groups
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,6 +9,10 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
@@ -20,6 +20,63 @@ ConvParam::ConvParam(ck::index_t n_dim,
|
||||
const std::vector<ck::index_t>& dilations,
|
||||
const std::vector<ck::index_t>& left_pads,
|
||||
const std::vector<ck::index_t>& right_pads)
|
||||
: num_dim_spatial_(static_cast<ck::long_index_t>(n_dim)),
|
||||
G_(static_cast<ck::long_index_t>(group_count)),
|
||||
N_(static_cast<ck::long_index_t>(n_batch)),
|
||||
K_(static_cast<ck::long_index_t>(n_out_channels)),
|
||||
C_(static_cast<ck::long_index_t>(n_in_channels)),
|
||||
filter_spatial_lengths_(num_dim_spatial_),
|
||||
input_spatial_lengths_(num_dim_spatial_),
|
||||
output_spatial_lengths_(num_dim_spatial_),
|
||||
conv_filter_strides_(num_dim_spatial_),
|
||||
conv_filter_dilations_(num_dim_spatial_),
|
||||
input_left_pads_(num_dim_spatial_),
|
||||
input_right_pads_(num_dim_spatial_)
|
||||
{
|
||||
if(static_cast<ck::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(input_right_pads_.size()) != num_dim_spatial_)
|
||||
{
|
||||
throw(
|
||||
std::runtime_error("ConvParam::ConvParam: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
for(ck::index_t i = 0; i < num_dim_spatial_; ++i)
|
||||
{
|
||||
filter_spatial_lengths_[i] = static_cast<ck::long_index_t>(filters_len[i]);
|
||||
input_spatial_lengths_[i] = static_cast<ck::long_index_t>(input_len[i]);
|
||||
conv_filter_strides_[i] = static_cast<ck::long_index_t>(strides[i]);
|
||||
conv_filter_dilations_[i] = static_cast<ck::long_index_t>(dilations[i]);
|
||||
input_left_pads_[i] = static_cast<ck::long_index_t>(left_pads[i]);
|
||||
input_right_pads_[i] = static_cast<ck::long_index_t>(right_pads[i]);
|
||||
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
conv_filter_strides_[i] +
|
||||
1;
|
||||
}
|
||||
}
|
||||
|
||||
ConvParam::ConvParam(ck::long_index_t n_dim,
|
||||
ck::long_index_t group_count,
|
||||
ck::long_index_t n_batch,
|
||||
ck::long_index_t n_out_channels,
|
||||
ck::long_index_t n_in_channels,
|
||||
const std::vector<ck::long_index_t>& filters_len,
|
||||
const std::vector<ck::long_index_t>& input_len,
|
||||
const std::vector<ck::long_index_t>& strides,
|
||||
const std::vector<ck::long_index_t>& dilations,
|
||||
const std::vector<ck::long_index_t>& left_pads,
|
||||
const std::vector<ck::long_index_t>& right_pads)
|
||||
: num_dim_spatial_(n_dim),
|
||||
G_(group_count),
|
||||
N_(n_batch),
|
||||
@@ -49,7 +106,8 @@ ConvParam::ConvParam(ck::index_t n_dim,
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
const ck::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
@@ -63,7 +121,7 @@ ConvParam::ConvParam()
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
|
||||
std::vector<ck::long_index_t> ConvParam::GetOutputSpatialLengths() const
|
||||
{
|
||||
return output_spatial_lengths_;
|
||||
}
|
||||
@@ -97,46 +155,46 @@ std::string get_conv_param_parser_helper_msg()
|
||||
|
||||
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
|
||||
{
|
||||
const ck::index_t G = std::stoi(argv[arg_idx++]);
|
||||
const ck::index_t N = std::stoi(argv[arg_idx++]);
|
||||
const ck::index_t K = std::stoi(argv[arg_idx++]);
|
||||
const ck::index_t C = std::stoi(argv[arg_idx++]);
|
||||
const ck::long_index_t G = std::stol(argv[arg_idx++]);
|
||||
const ck::long_index_t N = std::stol(argv[arg_idx++]);
|
||||
const ck::long_index_t K = std::stol(argv[arg_idx++]);
|
||||
const ck::long_index_t C = std::stol(argv[arg_idx++]);
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
|
||||
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
|
||||
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> input_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> conv_filter_strides(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> conv_filter_dilations(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> input_left_pads(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> input_right_pads(num_dim_spatial);
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
input_left_pads[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
input_right_pads[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return ck::utils::conv::ConvParam{num_dim_spatial,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -82,6 +82,29 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
|
||||
|
||||
for(ck::index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
input_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
|
||||
filter_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
|
||||
output_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
|
||||
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
|
||||
conv_filter_dilations_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
|
||||
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
|
||||
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
|
||||
}
|
||||
|
||||
std::cout << "input: " << input_host_result.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
@@ -161,16 +184,16 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_,
|
||||
conv_param.filter_spatial_lengths_,
|
||||
conv_param.output_spatial_lengths_,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
static_cast<ck::index_t>(conv_param.N_),
|
||||
static_cast<ck::index_t>(conv_param.K_),
|
||||
static_cast<ck::index_t>(conv_param.C_),
|
||||
input_spatial_lengths_i32,
|
||||
filter_spatial_lengths_i32,
|
||||
output_spatial_lengths_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -60,6 +60,29 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
|
||||
|
||||
for(ck::index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
input_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
|
||||
filter_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
|
||||
output_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
|
||||
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
|
||||
conv_filter_dilations_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
|
||||
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
|
||||
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
|
||||
}
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << host_output.mDesc << std::endl;
|
||||
@@ -143,16 +166,16 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_,
|
||||
conv_param.filter_spatial_lengths_,
|
||||
conv_param.GetOutputSpatialLengths(),
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
static_cast<ck::index_t>(conv_param.N_),
|
||||
static_cast<ck::index_t>(conv_param.K_),
|
||||
static_cast<ck::index_t>(conv_param.C_),
|
||||
input_spatial_lengths_i32,
|
||||
filter_spatial_lengths_i32,
|
||||
output_spatial_lengths_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
@@ -33,7 +33,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType>
|
||||
typename BComputeType = AComputeType,
|
||||
typename IndexType = ck::index_t>
|
||||
bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -57,16 +58,16 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
|
||||
std::array<IndexType, NDimSpatial> conv_filter_strides{};
|
||||
std::array<IndexType, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<IndexType, NDimSpatial> input_left_pads{};
|
||||
std::array<IndexType, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
|
||||
@@ -29,6 +29,12 @@ enum struct ConvDataType
|
||||
BF8_F8_F8, // 7
|
||||
};
|
||||
|
||||
// Index type selected on the profiler command line. The numeric value is what the
// user passes as an integer argument: 0 dispatches to ck::index_t, 1 dispatches to
// ck::long_index_t.
enum struct IndexType
{
    INDEX_T      = 0,
    LONG_INDEX_T = 1,
};
|
||||
|
||||
#define OP_NAME "grouped_conv_fwd"
|
||||
#define OP_DESC "Grouped Convolution Forward"
|
||||
|
||||
@@ -45,12 +51,13 @@ static void print_helper_msg()
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
|
||||
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< "arg5: verification (0: no, 1: yes)\n"
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg8: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
// clang-format on
|
||||
}
|
||||
@@ -60,7 +67,7 @@ static void print_helper_msg()
|
||||
int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
// 9 for control, 1 for num_dim_spatial
|
||||
if(argc < 9)
|
||||
if(argc < 10)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -68,20 +75,21 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const int init_method = std::stoi(argv[5]);
|
||||
const bool do_log = std::stoi(argv[6]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
const auto index_type = static_cast<IndexType>(std::stoi(argv[4]));
|
||||
const bool do_verification = std::stoi(argv[5]);
|
||||
const int init_method = std::stoi(argv[6]);
|
||||
const bool do_log = std::stoi(argv[7]);
|
||||
const bool time_kernel = std::stoi(argv[8]);
|
||||
const int num_dim_spatial = std::stoi(argv[9]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
|
||||
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
@@ -138,18 +146,43 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using AComputeType = decltype(a_compute_type);
|
||||
using BComputeType = decltype(b_compute_type);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AComputeType,
|
||||
BComputeType>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
if(index_type == IndexType::INDEX_T)
|
||||
{
|
||||
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
ck::index_t>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
else if(index_type == IndexType::LONG_INDEX_T)
|
||||
{
|
||||
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
ck::long_index_t>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this indexing data type is not implemented" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
// GNHWC_GKYXC_GNHWK
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
@@ -24,12 +24,12 @@ class TestConvUtil : public ::testing::Test
|
||||
128,
|
||||
192,
|
||||
256,
|
||||
std::vector<ck::index_t>(ndims, 3),
|
||||
std::vector<ck::index_t>(ndims, 71),
|
||||
std::vector<ck::index_t>(ndims, s),
|
||||
std::vector<ck::index_t>(ndims, d),
|
||||
std::vector<ck::index_t>(ndims, p),
|
||||
std::vector<ck::index_t>(ndims, p));
|
||||
std::vector<ck::long_index_t>(ndims, 3),
|
||||
std::vector<ck::long_index_t>(ndims, 71),
|
||||
std::vector<ck::long_index_t>(ndims, s),
|
||||
std::vector<ck::long_index_t>(ndims, d),
|
||||
std::vector<ck::long_index_t>(ndims, p),
|
||||
std::vector<ck::long_index_t>(ndims, p));
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -48,35 +48,35 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
|
||||
{
|
||||
// stride 2, dilation 1, pad 1
|
||||
SetNDParams(1, 2, 1, 1);
|
||||
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D."));
|
||||
out_spatial_len, std::vector<ck::long_index_t>{36}, "Error: ConvParams 1D."));
|
||||
|
||||
// stride 1, dilation 1, pad 1
|
||||
SetNDParams(1, 1, 1, 1);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}."));
|
||||
out_spatial_len, std::vector<ck::long_index_t>{71}, "Error: ConvParams 1D stride {1}."));
|
||||
|
||||
// stride 2, dilation 1, pad 2
|
||||
SetNDParams(1, 2, 1, 2);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{37},
|
||||
std::vector<ck::long_index_t>{37},
|
||||
"Error: ConvParams 1D padding left/right {2}."));
|
||||
|
||||
// stride 2, dilation 2, pad 2
|
||||
SetNDParams(1, 2, 2, 2);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}."));
|
||||
out_spatial_len, std::vector<ck::long_index_t>{36}, "Error: ConvParams 1D dilation {2}."));
|
||||
|
||||
// stride 3, dilation 2, pad 1
|
||||
SetNDParams(1, 3, 2, 1);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(
|
||||
ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{23},
|
||||
std::vector<ck::long_index_t>{23},
|
||||
"Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
|
||||
}
|
||||
|
||||
@@ -84,36 +84,38 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
|
||||
{
|
||||
// stride 2, dilation 1, pad 1
|
||||
SetNDParams(2, 2, 1, 1);
|
||||
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{36, 36},
|
||||
std::vector<ck::long_index_t>{36, 36},
|
||||
"Error: ConvParams 2D default constructor."));
|
||||
|
||||
// stride 1, dilation 1, pad 1
|
||||
SetNDParams(2, 1, 1, 1);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}."));
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::long_index_t>{71, 71},
|
||||
"Error: ConvParams 2D stride {1,1}."));
|
||||
|
||||
// stride 2, dilation 1, pad 2
|
||||
SetNDParams(2, 2, 1, 2);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{37, 37},
|
||||
std::vector<ck::long_index_t>{37, 37},
|
||||
"Error: ConvParams 2D padding left/right {2,2}."));
|
||||
|
||||
// stride 2, dilation 2, pad 2
|
||||
SetNDParams(2, 2, 2, 2);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}."));
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::long_index_t>{36, 36},
|
||||
"Error: ConvParams 2D dilation {2,2}."));
|
||||
|
||||
// stride 3, dilation 2, pad 1
|
||||
SetNDParams(2, 3, 2, 1);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(
|
||||
ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{23, 23},
|
||||
std::vector<ck::long_index_t>{23, 23},
|
||||
"Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
|
||||
}
|
||||
|
||||
@@ -121,29 +123,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
|
||||
{
|
||||
// stride 2, dilation 1, pad 1
|
||||
SetNDParams(3, 2, 1, 1);
|
||||
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D."));
|
||||
out_spatial_len, std::vector<ck::long_index_t>{36, 36, 36}, "Error: ConvParams 3D."));
|
||||
|
||||
// stride 1, dilation 1, pad 1
|
||||
SetNDParams(3, 1, 1, 1);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{71, 71, 71},
|
||||
std::vector<ck::long_index_t>{71, 71, 71},
|
||||
"Error: ConvParams 3D stride {1, 1, 1}."));
|
||||
|
||||
// stride 2, dilation 1, pad 2
|
||||
SetNDParams(3, 2, 1, 2);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{37, 37, 37},
|
||||
std::vector<ck::long_index_t>{37, 37, 37},
|
||||
"Error: ConvParams 3D padding left/right {2, 2, 2}."));
|
||||
|
||||
// stride 2, dilation 2, pad 2
|
||||
SetNDParams(3, 2, 2, 2);
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
|
||||
std::vector<ck::index_t>{36, 36, 36},
|
||||
std::vector<ck::long_index_t>{36, 36, 36},
|
||||
"Error: ConvParams 3D dilation {2, 2, 2}."));
|
||||
|
||||
// stride 3, dilation 2, pad 1
|
||||
@@ -151,6 +153,6 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
EXPECT_TRUE(ck::utils::check_err(
|
||||
out_spatial_len,
|
||||
std::vector<ck::index_t>{23, 23, 23},
|
||||
std::vector<ck::long_index_t>{23, 23, 23},
|
||||
"Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
using InLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<3, Tuple>;
|
||||
using IndexType = std::tuple_element_t<4, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
@@ -33,7 +34,10 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
@@ -46,30 +50,31 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>,
|
||||
std::tuple<ck::half_t, GNWC, GKXC, GNWK>,
|
||||
std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK>,
|
||||
std::tuple<int8_t, GNWC, GKXC, GNWK>>;
|
||||
using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK, ck::index_t>,
|
||||
std::tuple<ck::half_t, GNWC, GKXC, GNWK, ck::index_t>,
|
||||
std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK, ck::index_t>,
|
||||
std::tuple<int8_t, GNWC, GKXC, GNWK, ck::index_t>>;
|
||||
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
|
||||
std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK>,
|
||||
std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK>,
|
||||
std::tuple<int8_t, GNHWC, GKYXC, GNHWK>,
|
||||
std::tuple<float, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK, ck::index_t>,
|
||||
std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
|
||||
std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
|
||||
std::tuple<int8_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
|
||||
std::tuple<float, NHWGC, GKYXC, NHWGK, ck::index_t>,
|
||||
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK, ck::index_t>,
|
||||
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::index_t>,
|
||||
std::tuple<int8_t, NHWGC, GKYXC, NHWGK, ck::index_t>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
|
||||
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
|
||||
std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK>,
|
||||
std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK>,
|
||||
std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>;
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
|
||||
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
|
||||
std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
|
||||
std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
|
||||
std::tuple<float, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
|
||||
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
|
||||
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
|
||||
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>>;
|
||||
|
||||
using KernelTypes2dLargeCases = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>>;
|
||||
using KernelTypes2dLargeCases =
|
||||
::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK, ck::long_index_t>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
|
||||
@@ -153,5 +158,8 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
|
||||
// With supported NumGroupsToMerge > 1
|
||||
this->conv_params.push_back(
|
||||
{2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}});
|
||||
// When image is larger than 2GB
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 1, 256, 256, {3, 3}, {4096, 2048}, {1024, 1024}, {3, 3}, {1, 1}, {1, 1}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user