mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add support for GKCYX grouped conv weight (#2023)
* Grouped conv bwd weight GKCYX support * fix and changelog * fix * fix * fixes * comments * fix
This commit is contained in:
@@ -218,8 +218,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using EDataType = WeiDataType;
|
||||
|
||||
// If NGCHW then ADataType must be equal to BDataType
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
is_same_v<ADataType, BDataType>);
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
@@ -376,6 +376,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKCYXTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKYXCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
@@ -452,6 +458,28 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
I1,
|
||||
I1>;
|
||||
// NPerBlock is used for the first dim which is store dimension
|
||||
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
|
||||
// CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so
|
||||
// it is more flexible to use this dim for store dimension with such scalar
|
||||
// per vector.
|
||||
using GridwiseElementwiseWeightTransposeCast =
|
||||
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
|
||||
Tuple<GKCYXTransposeDescType>,
|
||||
Tuple<const AccDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<0, 1>,
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
Sequence<1>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
@@ -533,12 +561,15 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
@@ -550,7 +581,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
@@ -580,29 +611,21 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_right_pads,
|
||||
k_batch_)[I2];
|
||||
|
||||
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
|
||||
ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
|
||||
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ce_grid_desc_m_n_,
|
||||
GridwiseGemm::CalculateMBlock(GemmM),
|
||||
GridwiseGemm::CalculateNBlock(GemmN));
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
@@ -618,17 +641,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
|
||||
elementwise_block_2_ctile_map_ =
|
||||
is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>()
|
||||
? Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0),
|
||||
e_in_transpose_desc_.GetLength(I1)}
|
||||
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
|
||||
ce_grid_desc_m_n_.GetLength(I1)};
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(
|
||||
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
|
||||
128;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
@@ -638,14 +679,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(sizeof(AccDataType) *
|
||||
ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_,
|
||||
128) *
|
||||
128;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
// 1. We need to transpose A and B for NGCHW and NGKHW layouts
|
||||
// 2. If C format is GKCYX then tranpose during second stage.
|
||||
// If C format is GKYXC then just perform second stage.
|
||||
// Due to the fact that E workspace is always needed, we
|
||||
// allocate them as the first part of the workspace.
|
||||
// [EWorkspace, AWorkspace, BWorkspace]
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
@@ -672,6 +722,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
GKYXCTransposeDescType e_in_transpose_desc_;
|
||||
GKCYXTransposeDescType e_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
@@ -728,11 +780,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
|
||||
p_b_grid =
|
||||
type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
@@ -1373,41 +1425,72 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
float avg_time = 0.f;
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
arg.ce_elementwise_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
|
||||
std::array<index_t, I1> in_out_batch_strides = {
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
|
||||
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
I1,
|
||||
I1>;
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_,
|
||||
arg.Conv_G_,
|
||||
in_out_batch_strides,
|
||||
in_out_batch_strides);
|
||||
const auto kernel = kernel_elementwise<GridwiseElementwiseWeightTransposeCast,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
arg.ce_elementwise_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_,
|
||||
arg.Conv_G_,
|
||||
in_out_batch_strides,
|
||||
in_out_batch_strides);
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_a =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -1417,7 +1500,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.b_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
|
||||
BDataType* p_b_out_grid =
|
||||
type_convert<BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
@@ -1514,7 +1597,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1522,7 +1605,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1597,8 +1680,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
|
||||
{
|
||||
@@ -1767,8 +1850,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< NumGroupsToMerge;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
str << ", TransposeTransferSrcScalarPerVector: "
|
||||
<< TransposeTransferSrcScalarPerVector <<", "
|
||||
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
|
||||
|
||||
@@ -165,8 +165,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
using CDataType = WeiDataType;
|
||||
|
||||
// If NGCHW then ADataType must be equal to BDataType
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
is_same_v<ADataType, BDataType>);
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
@@ -301,7 +301,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
static constexpr index_t TransposeTransferSrcScalarPerVectorAligned =
|
||||
std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector);
|
||||
@@ -314,13 +314,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKCYXTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKYXCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
using GridwiseInOutTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -333,6 +339,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
// NPerBlock is used for the first dim which is store dimension
|
||||
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
|
||||
using GridwiseElementwiseWeightTranspose =
|
||||
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
|
||||
Tuple<GKCYXTransposeDescType>,
|
||||
Tuple<const CDataType*>,
|
||||
Tuple<CDataType*>,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
Sequence<1>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -452,13 +478,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -469,7 +497,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
@@ -487,12 +515,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
@@ -503,8 +526,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
@@ -520,31 +543,33 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes();
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(
|
||||
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
|
||||
128;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -552,6 +577,41 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(
|
||||
sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) *
|
||||
128;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
@@ -562,12 +622,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_;
|
||||
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
|
||||
GKYXCTransposeDescType e_in_transpose_desc_;
|
||||
GKCYXTransposeDescType e_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
|
||||
@@ -621,9 +684,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
CDataType* p_e_grid = arg.p_c_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
p_e_grid =
|
||||
type_convert<CDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(CDataType);
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_a =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -640,8 +713,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
GridwiseElementwiseTranspose,
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseInOutTranspose,
|
||||
GridwiseInOutTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
@@ -650,8 +723,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapTranspose,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
@@ -698,24 +771,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
has_main_loop>;
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
const auto clear_workspace = [&]() {
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(p_e_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
@@ -726,6 +811,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_e =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const CDataType* p_e_in_grid = static_cast<const CDataType*>(p_e_grid);
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseWeightTranspose,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<const CDataType*>,
|
||||
ck::Tuple<CDataType*>,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size_e),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(arg.p_c_grid_),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
@@ -763,7 +880,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -772,7 +889,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -810,8 +927,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0)
|
||||
{
|
||||
@@ -980,8 +1097,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
str << ", TransposeTransferSrcScalarPerVectorAligned: "
|
||||
<< TransposeTransferSrcScalarPerVectorAligned <<", "
|
||||
<< "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned;
|
||||
|
||||
@@ -502,6 +502,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
// NPerBlock is used for the first and second dim which to use
|
||||
// CDEBlockTransferScalarPerVector_NPerBlock for load and store during
|
||||
// transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to
|
||||
// NPerBlock so it is more flexible to use this dim for load store dimension
|
||||
// with such scalar per vector.
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
|
||||
Reference in New Issue
Block a user