mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add support for GKCYX grouped conv fwd (#2015)
* Add support for GKCYX grouped conv fwd
* fixes
* fix
* changelog
* Fixes
[ROCm/composable_kernel commit: 54c81a1fcf]
This commit is contained in:
@@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
### Added
|
||||
|
||||
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
|
||||
* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
@@ -496,11 +496,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides);
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
|
||||
@@ -534,11 +534,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
@@ -1425,11 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
GridwiseElementwiseTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
|
||||
@@ -453,11 +453,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
@@ -641,11 +641,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
GridwiseElementwiseTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
|
||||
@@ -314,8 +314,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
// NGCHW is not supported for multiAB
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
|
||||
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()) ||
|
||||
!(isMultiA || isMultiB));
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
@@ -355,11 +355,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
@@ -373,8 +371,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <typename BLay>
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
|
||||
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
|
||||
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
|
||||
@@ -387,11 +391,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
@@ -491,6 +493,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using GKCYXTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKYXCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
@@ -511,6 +520,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseWeightTranspose =
|
||||
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
|
||||
Tuple<GKYXCTransposeDescType>,
|
||||
Tuple<const BDataType*>,
|
||||
Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I0,
|
||||
I1>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
@@ -558,14 +585,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
@@ -744,8 +772,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -755,6 +783,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
|
||||
b_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
|
||||
b_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
@@ -764,6 +799,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
@@ -771,25 +808,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -797,6 +822,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
@@ -849,10 +911,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
GKCYXTransposeDescType b_in_transpose_desc_;
|
||||
GKYXCTransposeDescType b_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
|
||||
@@ -942,14 +1006,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
else
|
||||
{
|
||||
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
|
||||
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
@@ -978,8 +1056,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid, // Pass just A descriptor instead of tuple
|
||||
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
@@ -1009,50 +1087,71 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
const index_t b_grid_size =
|
||||
(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
|
||||
arg.b_in_transpose_desc_)
|
||||
: 0; // Dont run transpose B if not needed
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
|
||||
GridwiseElementwiseWeightTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<const BDataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(a_grid_size + b_grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.b_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.b_out_transpose_desc_),
|
||||
make_tuple(arg.p_as_grid_.At(I0)),
|
||||
make_tuple(arg.p_bs_grid_.At(I0)),
|
||||
make_tuple(p_a_out_grid),
|
||||
make_tuple(p_b_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{});
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_,
|
||||
element_wise::PassThrough{},
|
||||
a_grid_size);
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_out_grid =
|
||||
const EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_in_grid = arg.p_e_grid_;
|
||||
EDataType* p_e_out_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
@@ -1069,8 +1168,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_out_grid),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(p_e_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
@@ -1114,12 +1213,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t ConvStride = arg.conv_filter_strides_[i];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1131,11 +1230,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1156,10 +1255,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
@@ -1173,7 +1268,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1194,7 +1291,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// If not possible, check access per G
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
|
||||
(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()) &&
|
||||
G % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -1212,7 +1310,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
|
||||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
|
||||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
|
||||
is_same_v<BLayout, ctc::KZYXGC>)
|
||||
is_same_v<BLayout, ctc::KZYXGC> || is_same_v<BLayout, ctc::GKCX> ||
|
||||
is_same_v<BLayout, ctc::GKCYX> || is_same_v<BLayout, ctc::GKCZYX>)
|
||||
|
||||
{
|
||||
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
|
||||
@@ -1270,8 +1369,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
|
||||
@@ -325,9 +325,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
@@ -353,8 +353,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::GKZYXC,
|
||||
BLay>>;
|
||||
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
|
||||
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
|
||||
@@ -377,9 +385,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
@@ -426,6 +434,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using GKCYXTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKYXCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
@@ -446,6 +461,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseWeightTranspose =
|
||||
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
|
||||
Tuple<GKYXCTransposeDescType>,
|
||||
Tuple<const BDataType*>,
|
||||
Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I0,
|
||||
I1>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
@@ -508,12 +541,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
p_b_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
@@ -559,8 +593,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -570,9 +604,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
|
||||
b_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
|
||||
b_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
@@ -586,25 +629,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -612,6 +643,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
|
||||
@@ -661,10 +729,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
GKCYXTransposeDescType b_in_transpose_desc_;
|
||||
GKYXCTransposeDescType b_out_transpose_desc_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -702,18 +772,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
@@ -1012,50 +1087,68 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
const index_t b_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
|
||||
arg.b_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
|
||||
GridwiseElementwiseWeightTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<const BDataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(a_grid_size + b_grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.b_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.b_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(arg.p_b_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
make_tuple(p_b_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{});
|
||||
arg.elementwise_block_2_ctile_map_transpose_b_,
|
||||
element_wise::PassThrough{},
|
||||
a_grid_size);
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_out_grid =
|
||||
const EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_in_grid = arg.p_e_grid_;
|
||||
EDataType* p_e_out_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
@@ -1072,8 +1165,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_out_grid),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(p_e_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
@@ -1118,12 +1211,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t ConvStride = arg.conv_filter_strides_[i];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1135,11 +1228,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1171,7 +1264,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
|
||||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
|
||||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
|
||||
is_same_v<BLayout, ctc::KZYXGC>)
|
||||
is_same_v<BLayout, ctc::KZYXGC> || is_same_v<BLayout, ctc::GKCX> ||
|
||||
is_same_v<BLayout, ctc::GKCYX> || is_same_v<BLayout, ctc::GKCZYX>)
|
||||
|
||||
{
|
||||
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
|
||||
@@ -1184,8 +1278,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -59,6 +59,22 @@ constexpr bool is_NGCHW_GKYXC_NGKHW()
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCHW_GKCYX_NGKHW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKCYX> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCHW_NGKHW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
|
||||
}
|
||||
|
||||
// 3d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
|
||||
@@ -84,6 +100,21 @@ constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKCZYX> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCDHW_NGKDHW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
|
||||
{
|
||||
|
||||
@@ -41,13 +41,16 @@ __global__ void
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
template <typename GridwiseElementwiseFunctorA,
|
||||
typename GridwiseElementwiseFunctorB,
|
||||
typename InAGridDescTuple,
|
||||
typename InBGridDescTuple,
|
||||
typename OutAGridDescTuple,
|
||||
typename OutBGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename InADataTypePointerTuple,
|
||||
typename InBDataTypePointerTuple,
|
||||
typename OutADataTypePointerTuple,
|
||||
typename OutBDataTypePointerTuple,
|
||||
typename Block2TileMapA,
|
||||
typename Block2TileMapB,
|
||||
typename ElementwiseOperation>
|
||||
@@ -55,14 +58,14 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
|
||||
kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
|
||||
const InBGridDescTuple in_grid_desc_tuple_b,
|
||||
const OutAGridDescTuple out_grid_desc_tuple_a,
|
||||
const OutBGridDescTuple out_grid_desc_tuple_b,
|
||||
const InDataTypePointerTuple p_in_global_tuple_a,
|
||||
const InDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_b,
|
||||
const InADataTypePointerTuple p_in_global_tuple_a,
|
||||
const InBDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutADataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutBDataTypePointerTuple p_out_global_tuple_b,
|
||||
const Block2TileMapA block_2_tile_map_a,
|
||||
const Block2TileMapB block_2_tile_map_b,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
@@ -70,23 +73,23 @@ __global__ void
|
||||
{
|
||||
if(get_block_1d_id() < a_grid_size)
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_tuple_a,
|
||||
p_out_global_tuple_a,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
get_block_1d_id());
|
||||
GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_tuple_a,
|
||||
p_out_global_tuple_a,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
get_block_1d_id());
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_tuple_b,
|
||||
p_out_global_tuple_b,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
get_block_1d_id() - a_grid_size);
|
||||
GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_tuple_b,
|
||||
p_out_global_tuple_b,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
get_block_1d_id() - a_grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,9 +28,10 @@ struct TransformConvNGCHWToNHWGC
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
static auto
|
||||
MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
@@ -55,9 +56,10 @@ struct TransformConvNGCHWToNHWGC
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
static auto
|
||||
MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
@@ -82,9 +84,10 @@ struct TransformConvNGCHWToNHWGC
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
static auto
|
||||
MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
@@ -111,9 +114,10 @@ struct TransformConvNGCHWToNHWGC
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
static auto
|
||||
MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
@@ -140,9 +144,10 @@ struct TransformConvNGCHWToNHWGC
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
static auto
|
||||
MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
@@ -172,9 +177,10 @@ struct TransformConvNGCHWToNHWGC
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
static auto
|
||||
MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
@@ -203,11 +209,185 @@ struct TransformConvNGCHWToNHWGC
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto
|
||||
MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
const index_t& G = g_k_c_wis_lengths[I0];
|
||||
const index_t& K = g_k_c_wis_lengths[I1];
|
||||
const index_t& C = g_k_c_wis_lengths[I2];
|
||||
const index_t& X = g_k_c_wis_lengths[I3];
|
||||
|
||||
const index_t& GStride = g_k_c_wis_strides[I0];
|
||||
const index_t& KStride = g_k_c_wis_strides[I1];
|
||||
const index_t& CStride = g_k_c_wis_strides[I2];
|
||||
const index_t& XStride = g_k_c_wis_strides[I3];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride));
|
||||
const auto merged_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto
|
||||
MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_k_c_wis_lengths[I0];
|
||||
const index_t& K = g_k_c_wis_lengths[I1];
|
||||
const index_t& C = g_k_c_wis_lengths[I2];
|
||||
const index_t& X = g_k_c_wis_lengths[I3];
|
||||
|
||||
const index_t& GStride = g_k_c_wis_strides[I0];
|
||||
const index_t KStride = g_k_c_wis_strides[I1];
|
||||
const index_t CStride = 1;
|
||||
const index_t XStride = C;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride));
|
||||
const auto merged_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto
|
||||
MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_k_c_wis_lengths[I0];
|
||||
const index_t& K = g_k_c_wis_lengths[I1];
|
||||
const index_t& C = g_k_c_wis_lengths[I2];
|
||||
const index_t& Y = g_k_c_wis_lengths[I3];
|
||||
const index_t& X = g_k_c_wis_lengths[I4];
|
||||
|
||||
const index_t& GStride = g_k_c_wis_strides[I0];
|
||||
const index_t& KStride = g_k_c_wis_strides[I1];
|
||||
const index_t& CStride = g_k_c_wis_strides[I2];
|
||||
const index_t& YStride = g_k_c_wis_strides[I3];
|
||||
const index_t& XStride = g_k_c_wis_strides[I4];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(G, K, Y, X)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto
|
||||
MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_k_c_wis_lengths[I0];
|
||||
const index_t& K = g_k_c_wis_lengths[I1];
|
||||
const index_t& C = g_k_c_wis_lengths[I2];
|
||||
const index_t& Y = g_k_c_wis_lengths[I3];
|
||||
const index_t& X = g_k_c_wis_lengths[I4];
|
||||
|
||||
const index_t& GStride = g_k_c_wis_strides[I0];
|
||||
const index_t KStride = g_k_c_wis_strides[I1];
|
||||
const index_t CStride = 1;
|
||||
const index_t YStride = X * C;
|
||||
const index_t XStride = C;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(G, K, Y, X)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto
|
||||
MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_k_c_wis_lengths[I0];
|
||||
const index_t& K = g_k_c_wis_lengths[I1];
|
||||
const index_t& C = g_k_c_wis_lengths[I2];
|
||||
const index_t& Z = g_k_c_wis_lengths[I3];
|
||||
const index_t& Y = g_k_c_wis_lengths[I4];
|
||||
const index_t& X = g_k_c_wis_lengths[I5];
|
||||
|
||||
const index_t& GStride = g_k_c_wis_strides[I0];
|
||||
const index_t& KStride = g_k_c_wis_strides[I1];
|
||||
const index_t& CStride = g_k_c_wis_strides[I2];
|
||||
const index_t& ZStride = g_k_c_wis_strides[I3];
|
||||
const index_t& YStride = g_k_c_wis_strides[I4];
|
||||
const index_t& XStride = g_k_c_wis_strides[I5];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(G, K, C, Z, Y, X),
|
||||
make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto
|
||||
MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_k_c_wis_lengths[I0];
|
||||
const index_t& K = g_k_c_wis_lengths[I1];
|
||||
const index_t& C = g_k_c_wis_lengths[I2];
|
||||
const index_t& Z = g_k_c_wis_lengths[I3];
|
||||
const index_t& Y = g_k_c_wis_lengths[I4];
|
||||
const index_t& X = g_k_c_wis_lengths[I5];
|
||||
|
||||
const index_t& GStride = g_k_c_wis_strides[I0];
|
||||
const index_t KStride = g_k_c_wis_strides[I1];
|
||||
const index_t CStride = 1;
|
||||
const index_t ZStride = Y * X * C;
|
||||
const index_t YStride = X * C;
|
||||
const index_t XStride = C;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(G, K, C, Z, Y, X),
|
||||
make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
static auto TransposeInOutStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
|
||||
{
|
||||
if constexpr(device::is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
device::is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
|
||||
const auto G = g_n_c_wis_lengths[I0];
|
||||
@@ -236,6 +416,41 @@ struct TransformConvNGCHWToNHWGC
|
||||
return g_n_c_wis_strides;
|
||||
}
|
||||
}
|
||||
|
||||
static auto
|
||||
TransposeWeiStrides(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
|
||||
{
|
||||
if constexpr(device::is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
device::is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> g_k_c_wis_strides_transposed = g_k_c_wis_strides;
|
||||
const index_t C = g_k_c_wis_lengths[I2];
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const index_t X = g_k_c_wis_lengths[I4];
|
||||
g_k_c_wis_strides_transposed[I2] = 1;
|
||||
g_k_c_wis_strides_transposed[I3] = X * C;
|
||||
g_k_c_wis_strides_transposed[I4] = C;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const index_t Y = g_k_c_wis_lengths[I4];
|
||||
const index_t X = g_k_c_wis_lengths[I5];
|
||||
g_k_c_wis_strides_transposed[I2] = 1;
|
||||
g_k_c_wis_strides_transposed[I3] = Y * X * C;
|
||||
g_k_c_wis_strides_transposed[I4] = X * C;
|
||||
g_k_c_wis_strides_transposed[I5] = C;
|
||||
}
|
||||
return g_k_c_wis_strides_transposed;
|
||||
}
|
||||
else
|
||||
{
|
||||
// transpose not needed
|
||||
return g_k_c_wis_strides;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -71,6 +71,10 @@ using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using GKCX = ck::tensor_layout::convolution::GKCX;
|
||||
using GKCYX = ck::tensor_layout::convolution::GKCYX;
|
||||
using GKCZYX = ck::tensor_layout::convolution::GKCZYX;
|
||||
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -272,20 +272,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
|
||||
// layout NGCHW/GKYXC/NGKHW
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
|
||||
is_same_v<WeiLayout, GKCYX> && is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -294,13 +294,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -311,14 +311,46 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// layout NGCHW/GKYXC/NGKHW
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -326,14 +358,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
|
||||
is_same_v<BComputeType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -91,10 +91,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -107,10 +107,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -122,22 +122,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
|
||||
@@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -91,10 +91,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -107,10 +107,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -122,22 +122,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
|
||||
@@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -91,10 +91,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -107,10 +107,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -122,22 +122,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
|
||||
@@ -252,6 +252,55 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
|
||||
@@ -24,10 +24,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -54,10 +54,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -84,10 +84,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -114,10 +114,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
|
||||
@@ -15,6 +15,10 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
|
||||
# NGCHW, GKCYX, NGKHW
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
|
||||
# large tensor
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
@@ -27,11 +31,10 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp
|
||||
# NGCHW, GKCYX, NGKHW
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp
|
||||
#mem
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
|
||||
@@ -43,27 +46,24 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instance.cpp
|
||||
# NGCHW, GKCYX, NGKHW
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instance.cpp
|
||||
# NGCHW, GKCYX, NGKHW
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instance.cpp
|
||||
#comp
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
|
||||
# NGCHW, GKCYX, NGKHW
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp
|
||||
#dl
|
||||
# GNHWC, GKYXC, GNHWK
|
||||
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -39,7 +39,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -39,7 +39,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -1,64 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_int8_comp_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
@@ -1,39 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_int8_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,39 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_int8_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
BF16,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_inst
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_inst
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd3x3>{});
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_insta
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_insta
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd3x3>{});
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
@@ -9,10 +9,10 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd3x3>{});
|
||||
@@ -1,48 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_int8_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_int8_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd3x3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -16,6 +16,7 @@ enum struct ConvLayout
|
||||
GNHWC_GKYXC_GNHWK, // 0
|
||||
NHWGC_GKYXC_NHWGK, // 1
|
||||
NGCHW_GKYXC_NGKHW, // 2
|
||||
NGCHW_GKCYX_NGKHW, // 3
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -52,11 +53,13 @@ static void print_helper_msg()
|
||||
<< " 5: Input bf8, Weight bf8, Output fp8\n"
|
||||
<< " 6: Input fp8, Weight bf8, Output fp8\n"
|
||||
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
|
||||
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
|
||||
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
|
||||
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
|
||||
"G, K, Ho, Wo])\n"
|
||||
<< "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n"
|
||||
<< "arg5: verification (0: no, 1: yes)\n"
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -110,6 +113,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
// using GKCX = ck::tensor_layout::convolution::GKXC;
|
||||
using GKCYX = ck::tensor_layout::convolution::GKCYX;
|
||||
// using GKCZYX = ck::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
@@ -302,6 +309,25 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
|
||||
@@ -19,17 +19,20 @@ def init_const_args(args):
|
||||
|
||||
def run_ck_profiler_cmd(cmd):
|
||||
print("ckProfiler command:")
|
||||
print(cmd)
|
||||
cmd_concatenated_str = ""
|
||||
for arg in cmd:
|
||||
cmd_concatenated_str += arg + " "
|
||||
print(cmd_concatenated_str)
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
def parse_layouts(args):
|
||||
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
|
||||
args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 3
|
||||
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
args.layout = 3
|
||||
elif args.ck_profier_op == "grouped_conv_bwd_data":
|
||||
args.layout = 2
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -65,7 +65,10 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
|
||||
std::tuple<float, NGCHW, GKYXC, NGKHW>,
|
||||
std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>,
|
||||
std::tuple<ck::bhalf_t, NGCHW, GKYXC, NGKHW>,
|
||||
std::tuple<int8_t, NGCHW, GKYXC, NGKHW>>;
|
||||
std::tuple<int8_t, NGCHW, GKYXC, NGKHW>,
|
||||
std::tuple<float, NGCHW, GKCYX, NGKHW>,
|
||||
std::tuple<ck::half_t, NGCHW, GKCYX, NGKHW>,
|
||||
std::tuple<ck::bhalf_t, NGCHW, GKCYX, NGKHW>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
|
||||
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
|
||||
|
||||
Reference in New Issue
Block a user