mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64-bit indexing
* Add new grouped conv fwd kernel for large tensors
* Add large-tensor instances
* Fixes for transform conv to gemm
* Fixes
* Fixes
* Remove unneeded instances
* Fix examples
* Remove unneeded Ds arrays
* Fix tests
* Add 2GB check in gridwise dl
* Fixes
This commit is contained in:
@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
template <typename BLay>
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
template <typename ELay>
|
||||
__host__ __device__ static auto
|
||||
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static __device__ __host__ auto MakeArgument(
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const ck::Array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
|
||||
|
||||
static constexpr auto spatial_offset = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer =
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
|
||||
@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
|
||||
: independent_filter_stride;
|
||||
}
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
|
||||
@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
template <typename CLay>
|
||||
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
|
||||
@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
|
||||
@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ALayout,
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths,
|
||||
@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using EGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
|
||||
|
||||
@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise gemm v3 doesn't verify descriptors size
|
||||
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_as,
|
||||
const void* p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker for this device operation.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
|
||||
}
|
||||
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
EDataType* p_e_grid_;
|
||||
typename GridwiseGemm::RsGridPointer p_rs_grid_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
// tensor descriptors for problem definition
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
|
||||
@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
static constexpr auto BEnableLds =
|
||||
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc =
|
||||
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
|
||||
using BGridDesc =
|
||||
@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker for this device operation.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using GemmToConvFwdTransformer =
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
|
||||
b_g_k_c_xs_lengths[I2] = C;
|
||||
c_g_n_k_wos_lengths[I1] = N;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
|
||||
Reference in New Issue
Block a user