mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add support for NGCHW in grouped conv bwd wei (#1491)
* Add support for NGCHW in grouped conv bwd wei
* Comments fixes
* navi fixes
* Update function names
This commit is contained in:
@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
return false;
|
||||
|
||||
if constexpr(!((NDimSpatial == 1 &&
|
||||
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(NDimSpatial == 2 &&
|
||||
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(NDimSpatial == 3 &&
|
||||
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
|
||||
(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
index_t TransposeTransferSrcScalarPerVector = 1,
|
||||
index_t TransposeTransferDstScalarPerVector = 1>
|
||||
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using BDataType = InDataType;
|
||||
using EDataType = WeiDataType;
|
||||
|
||||
// If NGCHW then ADataType must be equal to BDataType
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
is_same_v<ADataType, BDataType>);
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
using BElementwiseOperation = InElementwiseOperation;
|
||||
using CDEElementwiseOperation = WeiElementwiseOperation;
|
||||
@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
batch)[I2];
|
||||
}
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& HiStride = g_n_c_wis_strides[3];
|
||||
const index_t& WiStride = g_n_c_wis_strides[4];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& DiStride = g_n_c_wis_strides[3];
|
||||
const index_t& HiStride = g_n_c_wis_strides[4];
|
||||
const index_t& WiStride = g_n_c_wis_strides[5];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t DiStride = Hi * Wi * G * C;
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
using InputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using OutputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
using GridwiseElementwise =
|
||||
using GridwiseElementwiseCast =
|
||||
GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
|
||||
Tuple<CElementwiseGridDesc_M_N>,
|
||||
Tuple<const AccDataType*>,
|
||||
@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
GridwiseElementwise<Tuple<InputTransposeDescType>,
|
||||
Tuple<OutputTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<TransposeTransferSrcScalarPerVector>,
|
||||
Sequence<TransposeTransferDstScalarPerVector>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
b_g_n_c_wis_strides;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
a_g_n_k_wos_strides;
|
||||
|
||||
// NGKHW - transpose needed
|
||||
// NGCHW / NGCDHW - transpose needed.
// The stride arrays passed to the conv-to-GEMM transformer must describe the tensors as
// they look AFTER the transpose: channel (C or K) contiguous, then G, then the spatial
// dimensions from innermost to outermost. The N stride (index 1) is left as supplied by
// the caller. Input strides are derived from the input spatial lengths and
// output-gradient strides from the output spatial lengths; the two differ whenever the
// convolution changes the spatial extent.
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
             is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
    // G stride and channel stride are the same for 2D and 3D.
    b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
    b_g_n_c_wis_strides_transposed[I2] = I1;
    a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
    a_g_n_k_wos_strides_transposed[I2] = I1;

    if constexpr(NDimSpatial == 2)
    {
        // Input: (Hi, Wi) strides
        b_g_n_c_wis_strides_transposed[I3] = input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
        b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
        // Output gradient: (Ho, Wo) strides
        a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
        a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
    }
    else if constexpr(NDimSpatial == 3)
    {
        // Input: (Di, Hi, Wi) strides
        b_g_n_c_wis_strides_transposed[I3] =
            input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
        b_g_n_c_wis_strides_transposed[I4] = input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
        b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
        // Output gradient: (Do, Ho, Wo) strides. These must use the OUTPUT width (Wo);
        // using the input width here mis-addresses the transposed buffer when Wi != Wo.
        a_g_n_k_wos_strides_transposed[I3] =
            output_spatial_lengths_[I1] * output_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
        a_g_n_k_wos_strides_transposed[I4] = output_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
        a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
    }
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
@@ -553,11 +750,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
ce_grid_desc_m_n_,
|
||||
GridwiseGemm::CalculateMBlock(GemmM),
|
||||
GridwiseGemm::CalculateNBlock(GemmN));
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
a_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
|
||||
b_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
b_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_;
|
||||
|
||||
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
p_b_grid =
|
||||
type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
sizeof(BDataType);
|
||||
}
|
||||
|
||||
// nullptr for output, will be set after workspace set
|
||||
typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
p_c_grid,
|
||||
GemmM,
|
||||
GemmN,
|
||||
GemmK,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
arg.k_batch_};
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
|
||||
@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
hip_check_error(hipMemsetAsync(
|
||||
gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
|
||||
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
std::array<index_t, I1> in_out_batch_strides = {
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
|
||||
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
in_out_batch_strides);
|
||||
};
|
||||
|
||||
float avg_time = RunGemmV3(arg, stream_config);
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_a =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
const index_t grid_size_b =
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
|
||||
arg.b_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
BDataType* p_b_out_grid =
|
||||
type_convert<BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
sizeof(BDataType);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size_a + grid_size_b),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.b_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.b_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(arg.p_b_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
make_tuple(p_b_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_,
|
||||
element_wise::PassThrough{},
|
||||
grid_size_a);
|
||||
}
|
||||
|
||||
avg_time += RunGemmV3(arg, stream_config);
|
||||
avg_time += launch_elementwise_kernel();
|
||||
return avg_time;
|
||||
}
|
||||
@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< NumGroupsToMerge
|
||||
<< ">";
|
||||
<< NumGroupsToMerge;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
str << ", TransposeTransferSrcScalarPerVector: "
|
||||
<< TransposeTransferSrcScalarPerVector <<", "
|
||||
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
|
||||
}
|
||||
|
||||
|
||||
str << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -925,7 +925,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -941,7 +941,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -960,7 +960,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
// If not possible, check access per G
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
|
||||
is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
G % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace device {
|
||||
|
||||
// 1d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NWGK_GKXC_NWGC()
|
||||
constexpr bool is_NWGC_GKXC_NWGK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNWK_GKXC_GNWC()
|
||||
constexpr bool is_GNWC_GKXC_GNWK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
|
||||
}
|
||||
// 2d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NHWGK_GKYXC_NHWGC()
|
||||
constexpr bool is_NHWGC_GKYXC_NHWGK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNHWK_GKYXC_GNHWC()
|
||||
constexpr bool is_GNHWC_GKYXC_GNHWK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCHW_GKYXC_NGKHW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
|
||||
}
|
||||
// 3d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
|
||||
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
|
||||
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
|
||||
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
|
||||
{
|
||||
return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
|
||||
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
|
||||
{
|
||||
return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
|
||||
return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
|
||||
{
|
||||
return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
|
||||
static constexpr const char* name = "NDHWGC";
|
||||
};
|
||||
|
||||
// input tensor
|
||||
// packed NGCW/NGCHW/NGCDHW
|
||||
// Layout tag for the packed NGCW input tensor; `name` is its printable identifier.
struct NGCW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGCW";
|
||||
};
|
||||
|
||||
// Layout tag for the packed NGCHW input tensor; `name` is its printable identifier.
struct NGCHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGCHW";
|
||||
};
|
||||
|
||||
// Layout tag for the packed NGCDHW input tensor; `name` is its printable identifier.
struct NGCDHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGCDHW";
|
||||
};
|
||||
|
||||
// input tensor
|
||||
// strided layout
|
||||
struct G_NW_C : public BaseTensorLayout
|
||||
@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
|
||||
static constexpr const char* name = "NDHWGK";
|
||||
};
|
||||
|
||||
// Layout tag NGKW; `name` is its printable identifier.
// NOTE(review): presumably the output-tensor counterpart of NGCW - confirm.
struct NGKW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGKW";
|
||||
};
|
||||
|
||||
// Layout tag NGKHW; `name` is its printable identifier.
// Matched as the output layout by is_NGCHW_GKYXC_NGKHW.
struct NGKHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGKHW";
|
||||
};
|
||||
|
||||
// Layout tag NGKDHW; `name` is its printable identifier.
// Matched as the output layout by is_NGCDHW_GKZYXC_NGKDHW.
struct NGKDHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGKDHW";
|
||||
};
|
||||
|
||||
// output tensor
|
||||
// strided layout
|
||||
struct G_NW_K : public BaseTensorLayout
|
||||
|
||||
Reference in New Issue
Block a user