mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add support for NGCHW in grouped conv fwd (#1499)
* Support NGCHW in grouped conv fwd * Remove not needed variable * Fixes
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
@@ -22,7 +23,6 @@
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
KPerBlock / K1Number,
|
||||
ConvBackwardWeightSpecialization>{};
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NDimSpatial,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
batch)[I2];
|
||||
}
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& HiStride = g_n_c_wis_strides[3];
|
||||
const index_t& WiStride = g_n_c_wis_strides[4];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& DiStride = g_n_c_wis_strides[3];
|
||||
const index_t& HiStride = g_n_c_wis_strides[4];
|
||||
const index_t& WiStride = g_n_c_wis_strides[5];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t DiStride = Hi * Wi * G * C;
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
using InputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using OutputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
I1>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
GridwiseElementwise<Tuple<InputTransposeDescType>,
|
||||
Tuple<OutputTransposeDescType>,
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
b_g_n_c_wis_strides;
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
a_g_n_k_wos_strides;
|
||||
|
||||
// NGKHW - transpose needed
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I2] = I1;
|
||||
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I2] = I1;
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I3] =
|
||||
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
|
||||
a_g_n_k_wos_strides_transposed[I3] =
|
||||
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I3] =
|
||||
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I4] =
|
||||
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
|
||||
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
|
||||
input_spatial_lengths_[I2] * Conv_G_ *
|
||||
Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I4] =
|
||||
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
|
||||
}
|
||||
}
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
a_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
|
||||
b_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
b_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_;
|
||||
|
||||
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
sizeof(BDataType);
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
// NGCHW is not supported for multiAB
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
|
||||
!(isMultiA || isMultiB));
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
NDimSpatial,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<const EDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I0,
|
||||
I1>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_bs_grid_{},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
num_group_{a_g_n_c_wis_lengths_[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
e_g_n_k_wos_strides_,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_},
|
||||
conv_N_per_block_{conv_to_gemm_transformer_.N_},
|
||||
a_grid_desc_m_k_{
|
||||
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
|
||||
@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// A/B/E Batch Stride
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_groups_ for multiple AB
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_(i) =
|
||||
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
|
||||
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
|
||||
|
||||
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
|
||||
// type is not tuple)
|
||||
@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// in case of MultiA is false but isMultiB is true
|
||||
// BatchStrideA_ is not tuple.
|
||||
compute_ptr_offset_of_n_.BatchStrideA_(i) =
|
||||
a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiB and not MultiA then p_as is single pointer
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as);
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
}
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_groups_ for multiple AB
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_(i) =
|
||||
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
|
||||
b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
|
||||
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
else
|
||||
{
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
|
||||
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ =
|
||||
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
|
||||
@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
|
||||
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
|
||||
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
ds_g_n_k_wos_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
ds_g_n_k_wos_strides_[i],
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_};
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
|
||||
});
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ =
|
||||
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// populate desc for Ds/E
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
a_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
|
||||
@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
|
||||
@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
}
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
|
||||
p_a_grid, // Pass just A descriptor instead of tuple
|
||||
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.p_as_grid_.At(I0)),
|
||||
make_tuple(p_a_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_out_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_in_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_out_grid),
|
||||
make_tuple(p_e_in_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC>)
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
// Check access per C
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
// If not possible, check access per G
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
|
||||
(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
|
||||
G % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
|
||||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
|
||||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
|
||||
is_same_v<ELayout, ctc::NDHWGK>)
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
NDimSpatial,
|
||||
MPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
|
||||
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<const EDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I0,
|
||||
I1>;
|
||||
|
||||
static auto
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: p_a_grid_{},
|
||||
p_b_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
num_group_{a_g_n_c_wis_lengths_[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
e_g_n_k_wos_strides_,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_},
|
||||
conv_N_per_block_{conv_to_gemm_transformer_.N_},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
|
||||
@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// A/B/E Batch/N Stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_a_grid_ = static_cast<const ADataType*>(p_as);
|
||||
p_b_grid_ = static_cast<const BDataType*>(p_bs);
|
||||
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
a_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const BDataType* p_b_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
|
||||
@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
// block-to-e-tile map
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
arg.p_a_grid_, arg.p_b_grid_, arg.p_e_grid_, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_out_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_in_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_out_grid),
|
||||
make_tuple(p_e_in_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx908")
|
||||
{
|
||||
@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC>)
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
const index_t C = arg.a_g_n_c_wis_lengths_[2];
|
||||
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
is_same_v<BLayout, ctc::KZYXGC>)
|
||||
|
||||
{
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
|
||||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
|
||||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
|
||||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
|
||||
is_same_v<ELayout, ctc::NDHWGK>)
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
const index_t K = arg.e_g_n_k_wos_lengths_[2];
|
||||
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCW_GKXC_NGKW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
|
||||
}
|
||||
|
||||
// 2d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NHWGC_GKYXC_NHWGK()
|
||||
@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
|
||||
{
|
||||
return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user