Add support for GKCYX grouped conv weight (#2023)

* Grouped conv bwd weight GKCYX support

* fix and changelog

* fix

* fix

* fixes

* comments

* fix
This commit is contained in:
Bartłomiej Kocot
2025-04-02 23:59:49 +02:00
committed by GitHub
parent e5ad48a784
commit 2ccf914888
101 changed files with 1004 additions and 356 deletions

View File

@@ -218,8 +218,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using EDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation;
@@ -376,6 +376,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
@@ -452,6 +458,28 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
I1,
I1>;
// NPerBlock is used for the first dim which is store dimension
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
// CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so
// it is more flexible to use this dim for store dimension with such scalar
// per vector.
using GridwiseElementwiseWeightTransposeCast =
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
Tuple<GKCYXTransposeDescType>,
Tuple<const AccDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<0, 1>,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
Sequence<1>,
I1,
I0>;
using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
@@ -533,12 +561,15 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
const auto descs =
conv_to_gemm_transformer_v2
@@ -550,7 +581,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
@@ -580,29 +611,21 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ce_grid_desc_m_n_,
GridwiseGemm::CalculateMBlock(GemmM),
GridwiseGemm::CalculateNBlock(GemmN));
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
@@ -618,17 +641,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
}
elementwise_block_2_ctile_map_ =
is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>()
? Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0),
e_in_transpose_desc_.GetLength(I1)}
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
ce_grid_desc_m_n_.GetLength(I1)};
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
// Align to 128B
return math::integer_divide_ceil(
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
128;
}
std::size_t GetWorkspaceBTensorSizeBytes() const
@@ -638,14 +679,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
// Align to 128B
return math::integer_divide_ceil(sizeof(AccDataType) *
ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_,
128) *
128;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose requires workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
// 1. We need to transpose A and B for NGCHW and NGKHW layouts
// 2. If C format is GKCYX then transpose during second stage.
// If C format is GKYXC then just perform second stage.
// Due to the fact that E workspace is always needed, we
// allocate them as the first part of the workspace.
// [EWorkspace, AWorkspace, BWorkspace]
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
@@ -672,6 +722,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
GKYXCTransposeDescType e_in_transpose_desc_;
GKCYXTransposeDescType e_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
@@ -728,11 +780,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
p_b_grid =
type_convert<const BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
@@ -1373,41 +1425,72 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
float avg_time = 0.f;
auto launch_elementwise_kernel = [&]() {
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.ce_elementwise_grid_desc_m_n_) *
arg.Conv_G_;
std::array<index_t, I1> in_out_batch_strides = {
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation,
I1,
I1>;
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.e_in_transpose_desc_);
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
arg.cde_element_op_,
arg.Conv_G_,
in_out_batch_strides,
in_out_batch_strides);
const auto kernel = kernel_elementwise<GridwiseElementwiseWeightTransposeCast,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
arg.cde_element_op_);
}
else
{
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.ce_elementwise_grid_desc_m_n_) *
arg.Conv_G_;
const auto kernel =
kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation,
I1,
I1>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
arg.cde_element_op_,
arg.Conv_G_,
in_out_batch_strides,
in_out_batch_strides);
}
};
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -1417,7 +1500,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
BDataType* p_b_out_grid =
type_convert<BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
@@ -1514,7 +1597,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -1522,7 +1605,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -1597,8 +1680,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
{
@@ -1767,8 +1850,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVector: "
<< TransposeTransferSrcScalarPerVector <<", "
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;

View File

@@ -165,8 +165,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using CDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation;
@@ -301,7 +301,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
static constexpr index_t TransposeTransferSrcScalarPerVectorAligned =
std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector);
@@ -314,13 +314,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
using GridwiseElementwiseTranspose =
using GridwiseInOutTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapTranspose,
element_wise::PassThrough,
BlockSize,
MPerBlock,
@@ -333,6 +339,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
I1,
I0>;
// NPerBlock is used for the first dim which is store dimension
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
Tuple<GKCYXTransposeDescType>,
Tuple<const CDataType*>,
Tuple<CDataType*>,
Block2TileMapTranspose,
element_wise::PassThrough,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
Sequence<1>,
I1,
I0>;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType,
@@ -452,13 +478,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -469,7 +497,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
@@ -487,12 +515,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
@@ -503,8 +526,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
@@ -520,31 +543,33 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose requires workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes();
// Align to 128B
return math::integer_divide_ceil(
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
128;
}
else
{
@@ -552,6 +577,41 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
// Align to 128B
return math::integer_divide_ceil(
sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) *
128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize();
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
@@ -562,12 +622,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_;
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
GKYXCTransposeDescType e_in_transpose_desc_;
GKCYXTransposeDescType e_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
@@ -621,9 +684,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
CDataType* p_e_grid = arg.p_c_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_e_grid =
type_convert<CDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(CDataType);
}
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -640,8 +713,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
GridwiseElementwiseTranspose,
auto kernel_transpose = kernel_elementwise_dual<GridwiseInOutTranspose,
GridwiseInOutTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
@@ -650,8 +723,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
Block2TileMapTranspose,
Block2TileMapTranspose,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
@@ -698,24 +771,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch<>,
has_main_loop>;
avg_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
const auto clear_workspace = [&]() {
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
hip_check_error(hipMemsetAsync(p_e_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
}
};
avg_time += launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
if(has_main_k0_block_loop)
@@ -726,6 +811,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
launch_kernel(integral_constant<bool, false>{});
}
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_e =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const CDataType* p_e_in_grid = static_cast<const CDataType*>(p_e_grid);
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseWeightTranspose,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<const CDataType*>,
ck::Tuple<CDataType*>,
Block2TileMapTranspose,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size_e),
dim3(BlockSize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_in_grid),
make_tuple(arg.p_c_grid_),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
return avg_time;
}
@@ -763,7 +880,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -772,7 +889,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -810,8 +927,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0)
{
@@ -980,8 +1097,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVectorAligned: "
<< TransposeTransferSrcScalarPerVectorAligned <<", "
<< "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned;

View File

@@ -502,6 +502,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
// NPerBlock is used for the first and second dim which to use
// CDEBlockTransferScalarPerVector_NPerBlock for load and store during
// transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to
// NPerBlock so it is more flexible to use this dim for load store dimension
// with such scalar per vector.
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,