Add support for NGCHW in grouped conv bwd wei (#1491)

* Add support for NGCHW in grouped conv bwd wei

* Comments fixes

* navi fixes

* Update function names
This commit is contained in:
Bartłomiej Kocot
2024-09-03 10:52:03 +02:00
committed by GitHub
parent a9b170b541
commit 73b67f290f
24 changed files with 893 additions and 89 deletions

View File

@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
return false;
if constexpr(!((NDimSpatial == 1 &&
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
(is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 2 &&
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 3 &&
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
{
return false;
}

View File

@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}

View File

@@ -22,6 +22,7 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t NumGroupsToMerge = 1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
index_t TransposeTransferSrcScalarPerVector = 1,
index_t TransposeTransferDstScalarPerVector = 1>
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using BDataType = InDataType;
using EDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation;
using BElementwiseOperation = InElementwiseOperation;
using CDEElementwiseOperation = WeiElementwiseOperation;
@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch)[I2];
}
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& HiStride = g_n_c_wis_strides[3];
const index_t& WiStride = g_n_c_wis_strides[4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& DiStride = g_n_c_wis_strides[3];
const index_t& HiStride = g_n_c_wis_strides[4];
const index_t& WiStride = g_n_c_wis_strides[5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
using InputTransposeDescType =
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
using OutputTransposeDescType =
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ComputeTypeA,
ComputeTypeB>;
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise =
using GridwiseElementwiseCast =
GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
Tuple<CElementwiseGridDesc_M_N>,
Tuple<const AccDataType*>,
@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1,
I1>;
using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<InputTransposeDescType>,
Tuple<OutputTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<TransposeTransferSrcScalarPerVector>,
Sequence<TransposeTransferDstScalarPerVector>,
I1,
I0>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
b_g_n_c_wis_strides;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
a_g_n_k_wos_strides;
// NGKHW - transpose needed
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
b_g_n_c_wis_strides_transposed[I2] = I1;
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
a_g_n_k_wos_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] =
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
}
else if constexpr(NDimSpatial == 3)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
input_spatial_lengths_[I2] * Conv_G_ *
Conv_K_;
a_g_n_k_wos_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
}
}
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
@@ -553,11 +750,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ce_grid_desc_m_n_,
GridwiseGemm::CalculateMBlock(GemmM),
GridwiseGemm::CalculateNBlock(GemmN));
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
a_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
b_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
b_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
}
std::size_t GetWorkspaceSizeBytes() const
{
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
else
{
return GetWorkspaceETensorSizeBytes();
}
}
const ADataType* p_a_grid_;
@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_;
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
p_b_grid =
type_convert<const BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
}
// nullptr for output, will be set after workspace set
typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_,
arg.p_b_grid_,
p_c_grid,
GemmM,
GemmN,
GemmK,
I0,
I0,
I0,
arg.k_batch_};
typename GridwiseGemm::Argument gemm_arg{
p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(
gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
};
const auto Run = [&](const auto& kernel) {
@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
auto launch_elementwise_kernel = [&]() {
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
std::array<index_t, I1> in_out_batch_strides = {
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
in_out_batch_strides);
};
float avg_time = RunGemmV3(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t grid_size_b =
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
BDataType* p_b_out_grid =
type_convert<BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size_a + grid_size_b),
dim3(BlockSize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(arg.p_b_grid_),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
grid_size_a);
}
avg_time += RunGemmV3(arg, stream_config);
avg_time += launch_elementwise_kernel();
return avg_time;
}
@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
return false;
}
if constexpr(NDimSpatial == 1)
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
{
return false;
}
if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
{
return false;
}
if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
{
return false;
}
}
return true;
}
@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge
<< ">";
<< NumGroupsToMerge;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVector: "
<< TransposeTransferSrcScalarPerVector <<", "
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
}
str << ">";
// clang-format on
return str.str();

View File

@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return false;
}
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}

View File

@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}

View File

@@ -925,7 +925,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
@@ -941,7 +941,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
@@ -960,7 +960,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;

View File

@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}

View File

@@ -12,7 +12,7 @@ namespace device {
// 1d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NWGK_GKXC_NWGC()
constexpr bool is_NWGC_GKXC_NWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNWK_GKXC_GNWC()
constexpr bool is_GNWC_GKXC_GNWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
}
// 2d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGK_GKYXC_NHWGC()
constexpr bool is_NHWGC_GKYXC_NHWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNHWK_GKYXC_GNHWC()
constexpr bool is_GNHWC_GKYXC_GNHWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_GKYXC_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
// 3d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
{
return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
{
return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
{
return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
}
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
static constexpr const char* name = "NDHWGC";
};
// input tensor
// packed NGCW/NGCHW/NGCDHW
struct NGCW : public BaseTensorLayout
{
static constexpr const char* name = "NGCW";
};
struct NGCHW : public BaseTensorLayout
{
static constexpr const char* name = "NGCHW";
};
struct NGCDHW : public BaseTensorLayout
{
static constexpr const char* name = "NGCDHW";
};
// input tensor
// strided layout
struct G_NW_C : public BaseTensorLayout
@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
static constexpr const char* name = "NDHWGK";
};
struct NGKW : public BaseTensorLayout
{
static constexpr const char* name = "NGKW";
};
struct NGKHW : public BaseTensorLayout
{
static constexpr const char* name = "NGKHW";
};
struct NGKDHW : public BaseTensorLayout
{
static constexpr const char* name = "NGKDHW";
};
// output tensor
// strided layout
struct G_NW_K : public BaseTensorLayout