mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Add support for NGCHW in grouped conv bwd wei (#1491)
* Add support for NGCHW in grouped conv bwd wei * Comments fixes * navi fixes * Update function names
This commit is contained in:
@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
return false;
|
||||
|
||||
if constexpr(!((NDimSpatial == 1 &&
|
||||
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(NDimSpatial == 2 &&
|
||||
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
|
||||
(NDimSpatial == 3 &&
|
||||
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
|
||||
(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
index_t TransposeTransferSrcScalarPerVector = 1,
|
||||
index_t TransposeTransferDstScalarPerVector = 1>
|
||||
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using BDataType = InDataType;
|
||||
using EDataType = WeiDataType;
|
||||
|
||||
// If NGCHW then ADataType must be equal to BDataType
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
is_same_v<ADataType, BDataType>);
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
using BElementwiseOperation = InElementwiseOperation;
|
||||
using CDEElementwiseOperation = WeiElementwiseOperation;
|
||||
@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
batch)[I2];
|
||||
}
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& HiStride = g_n_c_wis_strides[3];
|
||||
const index_t& WiStride = g_n_c_wis_strides[4];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& DiStride = g_n_c_wis_strides[3];
|
||||
const index_t& HiStride = g_n_c_wis_strides[4];
|
||||
const index_t& WiStride = g_n_c_wis_strides[5];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t DiStride = Hi * Wi * G * C;
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
using InputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using OutputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
using GridwiseElementwise =
|
||||
using GridwiseElementwiseCast =
|
||||
GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
|
||||
Tuple<CElementwiseGridDesc_M_N>,
|
||||
Tuple<const AccDataType*>,
|
||||
@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
GridwiseElementwise<Tuple<InputTransposeDescType>,
|
||||
Tuple<OutputTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<TransposeTransferSrcScalarPerVector>,
|
||||
Sequence<TransposeTransferDstScalarPerVector>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
b_g_n_c_wis_strides;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
a_g_n_k_wos_strides;
|
||||
|
||||
// NGKHW - transpose needed
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I2] = I1;
|
||||
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I2] = I1;
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I3] =
|
||||
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
|
||||
a_g_n_k_wos_strides_transposed[I3] =
|
||||
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I3] =
|
||||
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I4] =
|
||||
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
|
||||
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
|
||||
input_spatial_lengths_[I2] * Conv_G_ *
|
||||
Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I4] =
|
||||
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
|
||||
}
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
@@ -553,11 +750,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
ce_grid_desc_m_n_,
|
||||
GridwiseGemm::CalculateMBlock(GemmM),
|
||||
GridwiseGemm::CalculateNBlock(GemmN));
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
a_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
|
||||
b_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
b_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_;
|
||||
|
||||
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
p_b_grid =
|
||||
type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
sizeof(BDataType);
|
||||
}
|
||||
|
||||
// nullptr for output, will be set after workspace set
|
||||
typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
p_c_grid,
|
||||
GemmM,
|
||||
GemmN,
|
||||
GemmK,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
arg.k_batch_};
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
|
||||
@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
hip_check_error(hipMemsetAsync(
|
||||
gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
|
||||
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
std::array<index_t, I1> in_out_batch_strides = {
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
|
||||
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
in_out_batch_strides);
|
||||
};
|
||||
|
||||
float avg_time = RunGemmV3(arg, stream_config);
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_a =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
const index_t grid_size_b =
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
|
||||
arg.b_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
BDataType* p_b_out_grid =
|
||||
type_convert<BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
sizeof(BDataType);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size_a + grid_size_b),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.b_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.b_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(arg.p_b_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
make_tuple(p_b_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_,
|
||||
element_wise::PassThrough{},
|
||||
grid_size_a);
|
||||
}
|
||||
|
||||
avg_time += RunGemmV3(arg, stream_config);
|
||||
avg_time += launch_elementwise_kernel();
|
||||
return avg_time;
|
||||
}
|
||||
@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< NumGroupsToMerge
|
||||
<< ">";
|
||||
<< NumGroupsToMerge;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
str << ", TransposeTransferSrcScalarPerVector: "
|
||||
<< TransposeTransferSrcScalarPerVector <<", "
|
||||
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
|
||||
}
|
||||
|
||||
|
||||
str << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -925,7 +925,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -941,7 +941,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -960,7 +960,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
// If not possible, check access per G
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
|
||||
is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
G % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace device {
|
||||
|
||||
// 1d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NWGK_GKXC_NWGC()
|
||||
constexpr bool is_NWGC_GKXC_NWGK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNWK_GKXC_GNWC()
|
||||
constexpr bool is_GNWC_GKXC_GNWK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
|
||||
}
|
||||
// 2d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NHWGK_GKYXC_NHWGC()
|
||||
constexpr bool is_NHWGC_GKYXC_NHWGK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNHWK_GKYXC_GNHWC()
|
||||
constexpr bool is_GNHWC_GKYXC_GNHWK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCHW_GKYXC_NGKHW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
|
||||
}
|
||||
// 3d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
|
||||
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
|
||||
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
|
||||
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
|
||||
{
|
||||
return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
|
||||
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
|
||||
{
|
||||
return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
|
||||
return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
|
||||
{
|
||||
return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
|
||||
static constexpr const char* name = "NDHWGC";
|
||||
};
|
||||
|
||||
// input tensor
|
||||
// packed NGCW/NGCHW/NGCDHW
|
||||
struct NGCW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGCW";
|
||||
};
|
||||
|
||||
struct NGCHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGCHW";
|
||||
};
|
||||
|
||||
struct NGCDHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGCDHW";
|
||||
};
|
||||
|
||||
// input tensor
|
||||
// strided layout
|
||||
struct G_NW_C : public BaseTensorLayout
|
||||
@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
|
||||
static constexpr const char* name = "NDHWGK";
|
||||
};
|
||||
|
||||
struct NGKW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGKW";
|
||||
};
|
||||
|
||||
struct NGKHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGKHW";
|
||||
};
|
||||
|
||||
struct NGKDHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NGKDHW";
|
||||
};
|
||||
|
||||
// output tensor
|
||||
// strided layout
|
||||
struct G_NW_K : public BaseTensorLayout
|
||||
|
||||
@@ -41,6 +41,55 @@ __global__ void
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InAGridDescTuple,
|
||||
typename InBGridDescTuple,
|
||||
typename OutAGridDescTuple,
|
||||
typename OutBGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename Block2TileMapA,
|
||||
typename Block2TileMapB,
|
||||
typename ElementwiseOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
|
||||
const InBGridDescTuple in_grid_desc_tuple_b,
|
||||
const OutAGridDescTuple out_grid_desc_tuple_a,
|
||||
const OutBGridDescTuple out_grid_desc_tuple_b,
|
||||
const InDataTypePointerTuple p_in_global_tuple_a,
|
||||
const InDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_b,
|
||||
const Block2TileMapA block_2_tile_map_a,
|
||||
const Block2TileMapB block_2_tile_map_b,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const index_t a_grid_size)
|
||||
{
|
||||
if(get_block_1d_id() < a_grid_size)
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_tuple_a,
|
||||
p_out_global_tuple_a,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
get_block_1d_id());
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_tuple_b,
|
||||
p_out_global_tuple_b,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
get_block_1d_id() - a_grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
@@ -133,7 +182,8 @@ struct GridwiseElementwise
|
||||
const InDataTypePointerTuple& p_in_global_tuple,
|
||||
const OutDataTypePointerTuple& p_out_global_tuple,
|
||||
const Block2TileMap& block_2_tile_map,
|
||||
const ElementwiseOperation& elementwise_op)
|
||||
const ElementwiseOperation& elementwise_op,
|
||||
const index_t block_id = get_block_1d_id())
|
||||
{
|
||||
|
||||
constexpr auto src_datas = generate_tuple(
|
||||
@@ -169,7 +219,7 @@ struct GridwiseElementwise
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
|
||||
|
||||
const index_t m0_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
|
||||
|
||||
@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
using NGKW = ck::tensor_layout::convolution::NGKW;
|
||||
using NGKHW = ck::tensor_layout::convolution::NGKHW;
|
||||
using NGKDHW = ck::tensor_layout::convolution::NGKDHW;
|
||||
|
||||
//
|
||||
using NWGC = ck::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
using NGCW = ck::tensor_layout::convolution::NGCW;
|
||||
using NGCHW = ck::tensor_layout::convolution::NGCHW;
|
||||
using NGCDHW = ck::tensor_layout::convolution::NGCDHW;
|
||||
|
||||
//
|
||||
using G_K = ck::tensor_layout::convolution::G_K;
|
||||
using GK_Tuple = ck::Tuple<G_K>;
|
||||
|
||||
@@ -56,6 +56,46 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// NGCHW requires transpose, we use vector loads and stores params for them
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 1, 2>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 1>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8 ,1>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -367,6 +367,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -447,6 +462,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NGKDHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,6 +137,29 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
@@ -240,6 +263,29 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
|
||||
{
|
||||
return {0, 1, 2, 3};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKW>)
|
||||
{
|
||||
return {1, 0, 2, 3};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCHW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKHW>)
|
||||
{
|
||||
return {1, 0, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCDHW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
return {1, 0, 2, 3, 4, 5};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
|
||||
@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCW> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCHW> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
|
||||
@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKW> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKHW> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
|
||||
|
||||
@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,6 +8,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -16,6 +16,7 @@ enum struct ConvLayout
|
||||
GNCHW_GKCYX_GNKHW, // 0
|
||||
GNHWC_GKYXC_GNHWK, // 1
|
||||
NHWGC_GKYXC_NHWGK, // 2
|
||||
NGCHW_GKYXC_NGKHW, // 3
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -178,6 +179,13 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
@@ -224,6 +232,13 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
|
||||
@@ -23,6 +23,26 @@ def run_ck_profiler_cmd(cmd):
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
def parse_layouts(args):
|
||||
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
|
||||
args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 3
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
exit(1)
|
||||
elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
|
||||
args.in_layout == "NDHWC":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 2
|
||||
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
args.layout = 1
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
exit(1)
|
||||
|
||||
|
||||
def parse_data_type(args):
|
||||
if args.data_type == "fp32":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
|
||||
@@ -79,8 +99,7 @@ def add_conv_params_to_cmd(args, cmd):
|
||||
def run_ck_grouped_conv_fwd(args):
|
||||
args.ck_profier_op = "grouped_conv_fwd"
|
||||
parse_data_type(args)
|
||||
# default for MIOpen NHWGC
|
||||
args.layout = 1
|
||||
parse_layouts(args)
|
||||
# use int32 by default
|
||||
args.index_type = 0
|
||||
|
||||
@@ -99,8 +118,7 @@ def run_ck_grouped_conv_fwd(args):
|
||||
def run_ck_grouped_conv_bwd_data(args):
|
||||
args.ck_profier_op = "grouped_conv_bwd_data"
|
||||
parse_data_type(args)
|
||||
# default for MIOpen NHWGC
|
||||
args.layout = 1
|
||||
parse_layouts(args)
|
||||
|
||||
cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
|
||||
cmd += [str(args.data_type), str(args.layout)]
|
||||
@@ -117,8 +135,7 @@ def run_ck_grouped_conv_bwd_data(args):
|
||||
def run_ck_grouped_conv_bwd_weight(args):
|
||||
args.ck_profier_op = "grouped_conv_bwd_weight"
|
||||
parse_data_type(args)
|
||||
# default for MIOpen NHWGC
|
||||
args.layout = 2
|
||||
parse_layouts(args)
|
||||
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
|
||||
args.split_k_value = -1
|
||||
|
||||
@@ -181,8 +198,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_layout",
|
||||
"-I",
|
||||
default=-1,
|
||||
type=int,
|
||||
default="NCHW",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
|
||||
)
|
||||
|
||||
@@ -66,6 +66,12 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// Skip due to the lack of kernels for NGCDHW
|
||||
if constexpr(std::is_same_v<InLayout, NGCW> || std::is_same_v<InLayout, NGCHW> ||
|
||||
std::is_same_v<InLayout, NGCDHW>)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -139,7 +145,8 @@ using KernelTypes2d = ::testing::Types<
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
|
||||
std::tuple<float, float, float, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>>;
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>>;
|
||||
using KernelTypes3d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
@@ -148,7 +155,8 @@ using KernelTypes3d = ::testing::Types<
|
||||
std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
|
||||
std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCDHW, GKZYXC, NGKDHW, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
|
||||
|
||||
Reference in New Issue
Block a user