mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64 bit indexing
* Add new grouped conv fwd kernel for large tensors
* Add large tensor instances
* Fixes for transform conv to gemm
* Fixes
* fixes
* Remove unneeded instances
* examples fixes
* Remove unneeded ds arrays
* Fix tests
* Add 2GB check in gridwise dl
* Fixes
[ROCm/composable_kernel commit: 4ec5c52a0c]
This commit is contained in:
@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
// Pure-virtual overload of MakeArgumentPointer in which every lengths/strides
// array and every convolution parameter array (filter strides, dilations,
// left/right pads) is passed with long_index_t elements.
// Returns an owning pointer to a BaseArgument; implementations are provided
// by the derived device operations.
// NOTE(review): presumably the wide-index entry point introduced for large
// tensors (commit message: "Support 64 bit indexing") - confirm the width of
// long_index_t against its definition.
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
// Pure-virtual factory: returns an owning pointer to a BaseInvoker.
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
|
||||
@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
template <typename BLay>
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
template <typename ELay>
|
||||
__host__ __device__ static auto
|
||||
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Builds an Argument from a problem description whose lengths, strides and
// convolution parameters are given as long_index_t ck::Array values.
// Every such array is first copied into an index_t temporary through
// array_convert; the pointers and element-wise operations are forwarded
// unchanged. Returns the constructed Argument by value.
// NOTE(review): array_convert is defined outside this view - presumably an
// element-wise cast to index_t without a range check; confirm that callers
// only pass values representable in index_t.
// NOTE(review): the parameters are ck::Array while the temporaries below are
// std::array inside a __device__ __host__ function - confirm that Argument's
// constructor accepts std::array in this build configuration.
static __device__ __host__ auto MakeArgument(
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const ck::Array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
// index_t temporaries, one per long_index_t input array
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
// array_convert(destination, source): fill each temporary from its input
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// the D tensors are nested arrays, so they are converted one tensor at a time
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// construct the Argument from the converted copies
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
|
||||
|
||||
static constexpr auto spatial_offset = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer =
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
|
||||
@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
|
||||
: independent_filter_stride;
|
||||
}
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
|
||||
@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Builds an Argument from a problem description whose lengths, strides and
// convolution parameters are given as long_index_t std::array values.
// Every such array is copied into an index_t temporary through array_convert
// and the temporaries are passed to Argument; p_a, p_b, p_ds, p_e and the
// element-wise operations are forwarded unchanged.
// NOTE(review): array_convert is defined outside this view - presumably an
// element-wise cast to index_t without a range check; confirm that callers
// only pass values representable in index_t.
static auto
|
||||
MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
// index_t temporaries, one per long_index_t input array
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
// array_convert(destination, source): fill each temporary from its input
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// the D tensors are nested arrays, so they are converted one tensor at a time
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// construct the Argument from the converted copies
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a default-constructed Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// Override of the long_index_t MakeArgumentPointer overload.
// Copies every lengths/strides/stride/dilation/pad array into an index_t
// temporary through array_convert, then heap-allocates an Argument from those
// temporaries (std::make_unique) and returns it as a BaseArgument pointer.
// Pointers and element-wise operations are forwarded unchanged.
// NOTE(review): array_convert is defined outside this view - presumably an
// element-wise cast to index_t without a range check; confirm that callers
// only pass values representable in index_t.
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
// index_t temporaries, one per long_index_t input array
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
// array_convert(destination, source): fill each temporary from its input
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// the D tensors are nested arrays, so they are converted one tensor at a time
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// allocate the Argument from the converted copies
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
template <typename CLay>
|
||||
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
|
||||
@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
|
||||
@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ALayout,
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Builds an Argument from a problem description whose lengths, strides and
// convolution parameters are given as long_index_t std::array values.
// Every such array is copied into an index_t temporary through array_convert
// and the temporaries are passed to Argument; p_as, p_bs, p_ds, p_e and the
// element-wise operations are forwarded unchanged.
// NOTE(review): array_convert is defined outside this view - presumably an
// element-wise cast to index_t without a range check; confirm that callers
// only pass values representable in index_t.
static auto
|
||||
MakeArgument(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
// index_t temporaries, one per long_index_t input array
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
// array_convert(destination, source): fill each temporary from its input
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
// the D tensors are nested arrays, so they are converted one tensor at a time
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
// construct the Argument from the converted copies
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a default-constructed Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths,
|
||||
@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using EGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
|
||||
|
||||
@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
index_t conv_N_per_block_;
|
||||
|
||||
@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise gemm v3 doesn't verify descriptor sizes
|
||||
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_as,
|
||||
const void* p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
|
||||
}
|
||||
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
EDataType* p_e_grid_;
|
||||
typename GridwiseGemm::RsGridPointer p_rs_grid_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
// tensor descriptors for problem definition
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
|
||||
@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
static constexpr auto BEnableLds =
|
||||
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
|
||||
|
||||
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
|
||||
using AGridDesc =
|
||||
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
|
||||
using BGridDesc =
|
||||
@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
// tensor descriptors for problem definition
|
||||
index_t num_group_;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer_;
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) override
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_i32;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_i32;
|
||||
|
||||
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
|
||||
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
|
||||
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
|
||||
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
|
||||
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
|
||||
}
|
||||
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
|
||||
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
|
||||
array_convert(conv_filter_strides_i32, conv_filter_strides);
|
||||
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
|
||||
array_convert(input_left_pads_i32, input_left_pads);
|
||||
array_convert(input_right_pads_i32, input_right_pads);
|
||||
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths_i32,
|
||||
a_g_n_c_wis_strides_i32,
|
||||
b_g_k_c_xs_lengths_i32,
|
||||
b_g_k_c_xs_strides_i32,
|
||||
ds_g_n_k_wos_lengths_i32,
|
||||
ds_g_n_k_wos_strides_i32,
|
||||
e_g_n_k_wos_lengths_i32,
|
||||
e_g_n_k_wos_strides_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using GemmToConvFwdTransformer =
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
|
||||
b_g_k_c_xs_lengths[I2] = C;
|
||||
c_g_n_k_wos_lengths[I1] = N;
|
||||
|
||||
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
|
||||
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
|
||||
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
|
||||
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
|
||||
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
|
||||
|
||||
@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1>
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
private:
|
||||
@@ -46,10 +47,10 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
{
|
||||
const long_index_t a_element_space_size =
|
||||
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
|
||||
@@ -59,7 +60,7 @@ struct TransformConvFwdToGemm
|
||||
c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
const index_t N = a_g_n_c_wis_lengths[I1];
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if(element_space_size > TwoGB)
|
||||
{
|
||||
@@ -70,7 +71,7 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
// Find least divisor of N larger than element_space_size / TwoGB
|
||||
// Iterate up to sqrt(N). There are no divisors above this value.
|
||||
for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
|
||||
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
|
||||
least_divisor++)
|
||||
{
|
||||
if(N % least_divisor == 0)
|
||||
@@ -98,6 +99,53 @@ struct TransformConvFwdToGemm
|
||||
public:
|
||||
__host__ __device__ constexpr TransformConvFwdToGemm() {}
|
||||
|
||||
// Converting constructor: builds this transformer from another transformer-like
// object by copying every data member one-to-one and casting each value to this
// instantiation's IndexType.
//
// TransformConvFwdToGemmBase is any type that exposes the members read below
// (N_, Di_/Hi_/Wi_, Do_/Ho_/Wo_, Z_/Y_/X_, K_, C_, the stride, dilation and
// padding members, and ZYX_); they are read directly, so they must be
// accessible from here.
//
// NOTE(review): static_cast to a narrower IndexType silently truncates values
// that do not fit; presumably this is only used once the sizes are known to be
// representable in the destination type - confirm against callers.
template <typename TransformConvFwdToGemmBase>
|
||||
__host__ __device__
|
||||
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
|
||||
: N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
|
||||
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
|
||||
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
|
||||
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
|
||||
Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
|
||||
Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
|
||||
Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
|
||||
Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
|
||||
Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
|
||||
X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
|
||||
K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
|
||||
C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
|
||||
DiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DiStride_)},
|
||||
HiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HiStride_)},
|
||||
WiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WiStride_)},
|
||||
DoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DoStride_)},
|
||||
HoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HoStride_)},
|
||||
WoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WoStride_)},
|
||||
XStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.XStride_)},
|
||||
CStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorA_)},
|
||||
CStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorB_)},
|
||||
KStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorB_)},
|
||||
KStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorC_)},
|
||||
NStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorA_)},
|
||||
NStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorC_)},
|
||||
GStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorA_)},
|
||||
GStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorB_)},
|
||||
GStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorC_)},
|
||||
ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
|
||||
ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
|
||||
ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
|
||||
ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
|
||||
ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
|
||||
ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
|
||||
InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
|
||||
InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
|
||||
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
|
||||
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
|
||||
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
|
||||
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
typename ConvSpatialDimsType,
|
||||
index_t NDim = NDimSpatial,
|
||||
@@ -126,6 +174,8 @@ struct TransformConvFwdToGemm
|
||||
DiStride_{I1},
|
||||
HiStride_{I1},
|
||||
WiStride_{a_g_n_c_wis_strides[I3]},
|
||||
DoStride_{I1},
|
||||
HoStride_{I1},
|
||||
WoStride_{c_g_n_k_wos_strides[I3]},
|
||||
XStride_{b_g_k_c_xs_strides[I3]},
|
||||
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
|
||||
@@ -133,6 +183,7 @@ struct TransformConvFwdToGemm
|
||||
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
|
||||
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
|
||||
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
|
||||
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
|
||||
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
|
||||
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
|
||||
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
|
||||
@@ -150,10 +201,10 @@ struct TransformConvFwdToGemm
|
||||
InRightPadW_{input_right_pads[I0]},
|
||||
ZYX_{X_}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
@@ -164,7 +215,6 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
NDoHoWo_ = N_ * Wo_;
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -195,6 +245,8 @@ struct TransformConvFwdToGemm
|
||||
DiStride_{I1},
|
||||
HiStride_{a_g_n_c_wis_strides[I3]},
|
||||
WiStride_{a_g_n_c_wis_strides[I4]},
|
||||
DoStride_{I1},
|
||||
HoStride_{c_g_n_k_wos_strides[I3]},
|
||||
WoStride_{c_g_n_k_wos_strides[I4]},
|
||||
XStride_{b_g_k_c_xs_strides[I4]},
|
||||
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
|
||||
@@ -202,6 +254,7 @@ struct TransformConvFwdToGemm
|
||||
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
|
||||
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
|
||||
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
|
||||
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
|
||||
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
|
||||
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
|
||||
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
|
||||
@@ -219,10 +272,10 @@ struct TransformConvFwdToGemm
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
@@ -233,7 +286,6 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
NDoHoWo_ = N_ * Ho_ * Wo_;
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -264,6 +316,8 @@ struct TransformConvFwdToGemm
|
||||
DiStride_{a_g_n_c_wis_strides[I3]},
|
||||
HiStride_{a_g_n_c_wis_strides[I4]},
|
||||
WiStride_{a_g_n_c_wis_strides[I5]},
|
||||
DoStride_{c_g_n_k_wos_strides[I3]},
|
||||
HoStride_{c_g_n_k_wos_strides[I4]},
|
||||
WoStride_{c_g_n_k_wos_strides[I5]},
|
||||
XStride_{b_g_k_c_xs_strides[I5]},
|
||||
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
|
||||
@@ -271,6 +325,7 @@ struct TransformConvFwdToGemm
|
||||
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
|
||||
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
|
||||
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
|
||||
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
|
||||
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
|
||||
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
|
||||
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
|
||||
@@ -288,10 +343,10 @@ struct TransformConvFwdToGemm
|
||||
InRightPadW_{input_right_pads[I2]},
|
||||
ZYX_{Z_ * Y_ * X_}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
@@ -302,7 +357,122 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
|
||||
}
|
||||
|
||||
// Host-side check that both the input (A) and output (C) tensors described by
// this transformer span at most 2^31 bytes.
//
// The span of each tensor is computed in elements as
//   1 + sum over dims of (length - 1) * stride
// using the N, D, H, W and C (input) / K (output) lengths and strides stored in
// this object; the group (G) dimension is not part of the sum. The element
// count is then scaled by sizeof(ADataType) / sizeof(CDataType).
//
// Returns true only when both byte sizes are <= 2 GB.
//
// NOTE(review): the (length - 1) * stride products are evaluated in the type of
// their operands before the sum is stored into a long_index_t; with a 32-bit
// IndexType they could overflow first - confirm this is only called with a
// sufficiently wide IndexType.
__host__ bool AreDescriptorsSmallerThan2GB() const
|
||||
{
|
||||
// 2 GB expressed in bytes (1 << 31), built from a long_index_t literal.
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
// Element span of the input tensor: offset of the last element + 1.
const long_index_t in_desc_space_size =
|
||||
I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
|
||||
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
|
||||
// Element span of the output tensor, computed the same way.
const long_index_t out_desc_space_size =
|
||||
I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
|
||||
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
|
||||
|
||||
// Convert element spans to bytes and compare against the limit.
bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
|
||||
bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
|
||||
|
||||
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
|
||||
}
|
||||
|
||||
// Splits this convolution problem into a "left" and a "right" sub-problem by
// halving the output along one spatial dimension.
//
// The dimension is chosen in the fixed order D, then H, then W: the first one
// for which the split is allowed is used, and at most one dimension is split
// per call. If no dimension can be split, both returned transformers are plain
// copies of *this and both returned pointers equal the base pointers.
//
// Parameters:
//   a_grid_ptr_base - base pointer of the input (A) tensor for this problem.
//   c_grid_ptr_base - base pointer of the output (C) tensor for this problem.
//
// Returns a ck::make_tuple of:
//   (left transformer, right transformer,
//    A pointer for the right sub-problem, C pointer for the right sub-problem).
// The left sub-problem keeps using the base pointers passed in.
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
|
||||
CDataType* c_grid_ptr_base) const
|
||||
{
|
||||
// Create copies
|
||||
auto conv_to_gemm_transformer_left = *this;
|
||||
auto conv_to_gemm_transformer_right = *this;
|
||||
// Element offsets of the right sub-problem; they stay 0 when no split happens.
IndexType a_right_offset = 0;
|
||||
IndexType c_right_offset = 0;
|
||||
// Calculate real filter size
// (extent covered by the filter once dilation is applied)
|
||||
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
|
||||
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
|
||||
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
|
||||
// Calculate start position in input for right tensor
// (measured from the start of the left padding: InLeftPad* is subtracted
// further below when this is turned into an input offset)
|
||||
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
|
||||
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
|
||||
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
|
||||
// Calculate last position in input for left tensor
// (one past the last input position touched by the left half's final window)
|
||||
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
|
||||
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
|
||||
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
|
||||
// Allow to split if whole left padding will be in left tensor and right padding in right
|
||||
// tensor
// A dimension with a single output element (== 1) is never split.
|
||||
const bool is_possible_to_split_d = Do_ != 1 &&
|
||||
di_right_transformer_start_idx > InLeftPadD_ &&
|
||||
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
|
||||
const bool is_possible_to_split_h = Ho_ != 1 &&
|
||||
hi_right_transformer_start_idx > InLeftPadH_ &&
|
||||
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
|
||||
const bool is_possible_to_split_w = Wo_ != 1 &&
|
||||
wi_right_transformer_start_idx > InLeftPadW_ &&
|
||||
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
|
||||
|
||||
if(is_possible_to_split_d)
|
||||
{
|
||||
// Apply new sizes
|
||||
// Split output in half (the right part gets the extra element when Do_ is odd)
|
||||
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
|
||||
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
|
||||
// Assign left padding to left convolution
|
||||
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
|
||||
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
|
||||
// Assign right padding to right convolution
|
||||
conv_to_gemm_transformer_left.InRightPadD_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
|
||||
// Calculate new input size
// Left: everything up to the end of its last window, minus the padding.
// Right: the smaller of "what remains of the input after the right start"
// and "what the right half's windows need".
|
||||
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
|
||||
conv_to_gemm_transformer_right.Di_ =
|
||||
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
|
||||
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
|
||||
// (the lone ';' below is an empty statement with no effect)
;
|
||||
// Calculate offsets
// (in elements: start position of the right half times the D stride)
|
||||
a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
|
||||
c_right_offset = (Do_ / 2) * DoStride_;
|
||||
}
|
||||
// Same procedure as above, applied to the H dimension.
else if(is_possible_to_split_h)
|
||||
{
|
||||
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
|
||||
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
|
||||
|
||||
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
|
||||
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
|
||||
|
||||
conv_to_gemm_transformer_left.InRightPadH_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
|
||||
|
||||
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
|
||||
conv_to_gemm_transformer_right.Hi_ =
|
||||
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
|
||||
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
|
||||
a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
|
||||
c_right_offset = (Ho_ / 2) * HoStride_;
|
||||
}
|
||||
// Same procedure as above, applied to the W dimension.
else if(is_possible_to_split_w)
|
||||
{
|
||||
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
|
||||
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
|
||||
|
||||
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
|
||||
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
|
||||
|
||||
conv_to_gemm_transformer_left.InRightPadW_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
|
||||
|
||||
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
|
||||
conv_to_gemm_transformer_right.Wi_ =
|
||||
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
|
||||
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
|
||||
|
||||
a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
|
||||
c_right_offset = (Wo_ / 2) * WoStride_;
|
||||
}
|
||||
// Return left transform, right transformer, right offset to Input and right offset to
|
||||
// Output
// (the offsets are returned already applied to the base pointers)
|
||||
return ck::make_tuple(conv_to_gemm_transformer_left,
|
||||
conv_to_gemm_transformer_right,
|
||||
a_grid_ptr_base + a_right_offset,
|
||||
c_grid_ptr_base + c_right_offset);
|
||||
}
|
||||
|
||||
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimension as
|
||||
@@ -320,20 +490,27 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
|
||||
make_tuple(WiStride_, CStrideTensorA_));
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -527,20 +704,29 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
|
||||
make_tuple(WiStride_, CStrideTensorA_));
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -759,20 +945,34 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
|
||||
make_tuple(WiStride_, CStrideTensorA_));
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_,
|
||||
DiStride_,
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -1119,45 +1319,70 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename CLayout,
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<NDimSp == 2 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
|
||||
// MakeCDescriptor_M_N overload selected when NDimSp == 3 and CLayout is
// tensor_layout::convolution::G_K.
//
// Returns a naive 2-D tensor descriptor with lengths (N_ * Do_ * Ho_ * Wo_, K_)
// and strides (I0, KStrideTensorC_).
//
// NOTE(review): I0 is presumably the compile-time constant 0, which would make
// every M index address the same K-vector (a broadcast, e.g. for an output
// bias) - confirm against the definition of I0.
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<NDimSp == 3 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
typename std::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNWK>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo_),
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only matrices from diagonal. X_or returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
@@ -1167,7 +1392,7 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo_),
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
@@ -1175,45 +1400,146 @@ struct TransformConvFwdToGemm
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// for output bias
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
|
||||
bool>::type = false>
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename std::enable_if<
|
||||
NDimSp == 2 && (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(NStrideTensorC_,
|
||||
HoStride_,
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only matrices from diagonal. X_or returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
index_t N_;
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
typename std::enable_if<
|
||||
NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
|
||||
private:
|
||||
const index_t Di_, Hi_, Wi_;
|
||||
const index_t Do_, Ho_, Wo_;
|
||||
const index_t Z_, Y_, X_;
|
||||
const index_t K_, C_;
|
||||
const index_t DiStride_, HiStride_, WiStride_;
|
||||
const index_t WoStride_;
|
||||
const index_t XStride_;
|
||||
const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
|
||||
const index_t NStrideTensorA_;
|
||||
const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
|
||||
const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
|
||||
const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
const index_t InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
const index_t ZYX_;
|
||||
index_t NDoHoWo_;
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(NStrideTensorC_,
|
||||
DoStride_,
|
||||
HoStride_,
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We only need the matrices on the diagonal. XOR returns 0 for equal
|
||||
// values, so a matrix that is not on the diagonal will be stored in the padding.
|
||||
// To avoid a modulo after the xor, we assume that NumGroupsToMerge is a power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
IndexType N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
IndexType Z_, Y_, X_;
|
||||
IndexType K_, C_;
|
||||
IndexType DiStride_, HiStride_, WiStride_;
|
||||
IndexType DoStride_, HoStride_, WoStride_;
|
||||
IndexType XStride_;
|
||||
IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
|
||||
IndexType NStrideTensorA_, NStrideTensorC_;
|
||||
IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
|
||||
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
|
||||
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
IndexType ZYX_;
|
||||
};
|
||||
|
||||
// wrapper class to call member functions on TransformConvToGemm struct at runtime
|
||||
@@ -1230,17 +1556,17 @@ struct TransformConv
|
||||
if(NDimSpatial == 2)
|
||||
{
|
||||
return conv_fwd_to_gemm
|
||||
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
|
||||
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK, 2>();
|
||||
}
|
||||
else if(NDimSpatial == 3)
|
||||
{
|
||||
return conv_fwd_to_gemm
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK, 3>();
|
||||
}
|
||||
else if(NDimSpatial == 1)
|
||||
{
|
||||
return conv_fwd_to_gemm
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
|
||||
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK, 1>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/f8_utils.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/array.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for gfx94x models
|
||||
@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Y, typename X, std::size_t NumElems>
|
||||
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
|
||||
const std::array<X, NumElems>& x)
|
||||
{
|
||||
for(std::size_t i = 0; i < NumElems; i++)
|
||||
{
|
||||
y[i] = type_convert<Y>(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Y, typename X, index_t NumElems>
|
||||
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
|
||||
{
|
||||
for(std::size_t i = 0; i < NumElems; i++)
|
||||
{
|
||||
y[i] = type_convert<Y>(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
Reference in New Issue
Block a user