Add support for GKCYX grouped conv fwd (#2015)

* Add support for GKCYX grouped conv fwd

* fixes

* fix

* changelog

* Fixes

[ROCm/composable_kernel commit: 54c81a1fcf]
This commit is contained in:
Bartłomiej Kocot
2025-03-26 21:13:38 +01:00
committed by GitHub
parent ba16351a03
commit f967fd7296
39 changed files with 1005 additions and 570 deletions

View File

@@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added support for the GKCYX layout in grouped convolution forward (NGCHW/GKCYX/NGKHW; the number of instances in the instance factory for NGCHW/GKYXC/NGKHW has been reduced).
### Optimized

View File

@@ -496,11 +496,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
input_right_pads_{input_right_pads}
{
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {

View File

@@ -534,11 +534,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
const auto descs =
conv_to_gemm_transformer_v2
@@ -1425,11 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
GridwiseElementwiseTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,

View File

@@ -453,11 +453,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
const auto descs =
conv_to_gemm_transformer
@@ -641,11 +641,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
GridwiseElementwiseTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,

View File

@@ -314,8 +314,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
// NGCHW is not supported for multiAB
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()) ||
!(isMultiA || isMultiB));
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
@@ -355,11 +355,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -373,8 +371,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay>
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
ctc::GKYXC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
@@ -387,11 +391,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
@@ -491,6 +493,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
@@ -511,6 +520,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
I1,
I0>;
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
Tuple<GKYXCTransposeDescType>,
Tuple<const BDataType*>,
Tuple<BDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<1>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
@@ -558,14 +585,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
@@ -744,8 +772,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -755,6 +783,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
b_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
b_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
@@ -764,6 +799,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
@@ -771,25 +808,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceATensorSizeBytes() const
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
}
else
{
@@ -797,6 +822,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
@@ -849,10 +911,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
GKCYXTransposeDescType b_in_transpose_desc_;
GKYXCTransposeDescType b_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
@@ -942,14 +1006,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
{
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
@@ -978,8 +1056,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid, // Pass just A descriptor instead of tuple
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
@@ -1009,50 +1087,71 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t b_grid_size =
(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_)
: 0; // Dont run transpose B if not needed
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
GridwiseElementwiseWeightTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const BDataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(a_grid_size + b_grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_as_grid_.At(I0)),
make_tuple(arg.p_bs_grid_.At(I0)),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
a_grid_size);
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
const EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
EDataType* p_e_out_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
@@ -1069,8 +1168,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
make_tuple(p_e_out_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
@@ -1114,12 +1213,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1131,11 +1230,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1156,10 +1255,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
}
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
}
if constexpr(NumGroupsToMerge > 1)
@@ -1173,7 +1268,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>() ||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()))
{
return false;
}
@@ -1194,7 +1291,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()) &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;
@@ -1212,7 +1310,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
is_same_v<BLayout, ctc::KZYXGC> || is_same_v<BLayout, ctc::GKCX> ||
is_same_v<BLayout, ctc::GKCYX> || is_same_v<BLayout, ctc::GKCZYX>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
@@ -1270,8 +1369,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{

View File

@@ -325,9 +325,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
@@ -353,8 +353,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static auto
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
ctc::GKYXC,
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::GKZYXC,
BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
@@ -377,9 +385,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
@@ -426,6 +434,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
@@ -446,6 +461,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
I1,
I0>;
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
Tuple<GKYXCTransposeDescType>,
Tuple<const BDataType*>,
Tuple<BDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<1>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
@@ -508,12 +541,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
@@ -559,8 +593,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -570,9 +604,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
b_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
b_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
@@ -586,25 +629,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
std::size_t GetWorkspaceATensorSizeBytes() const
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
}
else
{
@@ -612,6 +643,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
void Print() const
{
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
@@ -661,10 +729,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// block-to-e-tile map
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
GKCYXTransposeDescType b_in_transpose_desc_;
GKYXCTransposeDescType b_out_transpose_desc_;
};
// Invoker
@@ -702,18 +772,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
typename GridwiseGemm::Argument gemm_arg{
p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
@@ -1012,50 +1087,68 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t b_grid_size =
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
GridwiseElementwiseWeightTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const BDataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(a_grid_size + b_grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(arg.p_b_grid_),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
a_grid_size);
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
const EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
EDataType* p_e_out_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
@@ -1072,8 +1165,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
make_tuple(p_e_out_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
@@ -1118,12 +1211,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1135,11 +1228,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1171,7 +1264,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
is_same_v<BLayout, ctc::KZYXGC> || is_same_v<BLayout, ctc::GKCX> ||
is_same_v<BLayout, ctc::GKCYX> || is_same_v<BLayout, ctc::GKCZYX>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
@@ -1184,8 +1278,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -59,6 +59,22 @@ constexpr bool is_NGCHW_GKYXC_NGKHW()
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_GKCYX_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKCYX> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
// 3d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
@@ -84,6 +100,21 @@ constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKCZYX> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCDHW_NGKDHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
{

View File

@@ -41,13 +41,16 @@ __global__ void
elementwise_op);
}
template <typename GridwiseElementwiseFunctor,
template <typename GridwiseElementwiseFunctorA,
typename GridwiseElementwiseFunctorB,
typename InAGridDescTuple,
typename InBGridDescTuple,
typename OutAGridDescTuple,
typename OutBGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename InADataTypePointerTuple,
typename InBDataTypePointerTuple,
typename OutADataTypePointerTuple,
typename OutBDataTypePointerTuple,
typename Block2TileMapA,
typename Block2TileMapB,
typename ElementwiseOperation>
@@ -55,14 +58,14 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
const InBGridDescTuple in_grid_desc_tuple_b,
const OutAGridDescTuple out_grid_desc_tuple_a,
const OutBGridDescTuple out_grid_desc_tuple_b,
const InDataTypePointerTuple p_in_global_tuple_a,
const InDataTypePointerTuple p_in_global_tuple_b,
const OutDataTypePointerTuple p_out_global_tuple_a,
const OutDataTypePointerTuple p_out_global_tuple_b,
const InADataTypePointerTuple p_in_global_tuple_a,
const InBDataTypePointerTuple p_in_global_tuple_b,
const OutADataTypePointerTuple p_out_global_tuple_a,
const OutBDataTypePointerTuple p_out_global_tuple_b,
const Block2TileMapA block_2_tile_map_a,
const Block2TileMapB block_2_tile_map_b,
const ElementwiseOperation elementwise_op,
@@ -70,23 +73,23 @@ __global__ void
{
if(get_block_1d_id() < a_grid_size)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
out_grid_desc_tuple_a,
p_in_global_tuple_a,
p_out_global_tuple_a,
block_2_tile_map_a,
elementwise_op,
get_block_1d_id());
GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
out_grid_desc_tuple_a,
p_in_global_tuple_a,
p_out_global_tuple_a,
block_2_tile_map_a,
elementwise_op,
get_block_1d_id());
}
else
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
out_grid_desc_tuple_b,
p_in_global_tuple_b,
p_out_global_tuple_b,
block_2_tile_map_b,
elementwise_op,
get_block_1d_id() - a_grid_size);
GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
out_grid_desc_tuple_b,
p_in_global_tuple_b,
p_out_global_tuple_b,
block_2_tile_map_b,
elementwise_op,
get_block_1d_id() - a_grid_size);
}
}

View File

@@ -28,9 +28,10 @@ struct TransformConvNGCHWToNHWGC
static constexpr auto I5 = Number<5>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
const index_t split_n_size = 1)
static auto
MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
const index_t split_n_size = 1)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
@@ -55,9 +56,10 @@ struct TransformConvNGCHWToNHWGC
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
const index_t split_n_size = 1)
static auto
MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
const index_t split_n_size = 1)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
@@ -82,9 +84,10 @@ struct TransformConvNGCHWToNHWGC
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
const index_t split_n_size = 1)
static auto
MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
const index_t split_n_size = 1)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
@@ -111,9 +114,10 @@ struct TransformConvNGCHWToNHWGC
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
const index_t split_n_size = 1)
static auto
MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
const index_t split_n_size = 1)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
@@ -140,9 +144,10 @@ struct TransformConvNGCHWToNHWGC
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
const index_t split_n_size = 1)
static auto
MakeNGCHWTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
const index_t split_n_size = 1)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
@@ -172,9 +177,10 @@ struct TransformConvNGCHWToNHWGC
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
const index_t split_n_size = 1)
static auto
MakeNHWGCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
const index_t split_n_size = 1)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
@@ -203,11 +209,185 @@ struct TransformConvNGCHWToNHWGC
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto
MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
const index_t& G = g_k_c_wis_lengths[I0];
const index_t& K = g_k_c_wis_lengths[I1];
const index_t& C = g_k_c_wis_lengths[I2];
const index_t& X = g_k_c_wis_lengths[I3];
const index_t& GStride = g_k_c_wis_strides[I0];
const index_t& KStride = g_k_c_wis_strides[I1];
const index_t& CStride = g_k_c_wis_strides[I2];
const index_t& XStride = g_k_c_wis_strides[I3];
const auto desc = make_naive_tensor_descriptor(
make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride));
const auto merged_desc = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto
MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
const index_t& G = g_k_c_wis_lengths[I0];
const index_t& K = g_k_c_wis_lengths[I1];
const index_t& C = g_k_c_wis_lengths[I2];
const index_t& X = g_k_c_wis_lengths[I3];
const index_t& GStride = g_k_c_wis_strides[I0];
const index_t KStride = g_k_c_wis_strides[I1];
const index_t CStride = 1;
const index_t XStride = C;
const auto desc = make_naive_tensor_descriptor(
make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride));
const auto merged_desc = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto
MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
const index_t& G = g_k_c_wis_lengths[I0];
const index_t& K = g_k_c_wis_lengths[I1];
const index_t& C = g_k_c_wis_lengths[I2];
const index_t& Y = g_k_c_wis_lengths[I3];
const index_t& X = g_k_c_wis_lengths[I4];
const index_t& GStride = g_k_c_wis_strides[I0];
const index_t& KStride = g_k_c_wis_strides[I1];
const index_t& CStride = g_k_c_wis_strides[I2];
const index_t& YStride = g_k_c_wis_strides[I3];
const index_t& XStride = g_k_c_wis_strides[I4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(G, K, Y, X)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto
MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
const index_t& G = g_k_c_wis_lengths[I0];
const index_t& K = g_k_c_wis_lengths[I1];
const index_t& C = g_k_c_wis_lengths[I2];
const index_t& Y = g_k_c_wis_lengths[I3];
const index_t& X = g_k_c_wis_lengths[I4];
const index_t& GStride = g_k_c_wis_strides[I0];
const index_t KStride = g_k_c_wis_strides[I1];
const index_t CStride = 1;
const index_t YStride = X * C;
const index_t XStride = C;
const auto desc = make_naive_tensor_descriptor(
make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(G, K, Y, X)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto
MakeGKCYXTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
const index_t& G = g_k_c_wis_lengths[I0];
const index_t& K = g_k_c_wis_lengths[I1];
const index_t& C = g_k_c_wis_lengths[I2];
const index_t& Z = g_k_c_wis_lengths[I3];
const index_t& Y = g_k_c_wis_lengths[I4];
const index_t& X = g_k_c_wis_lengths[I5];
const index_t& GStride = g_k_c_wis_strides[I0];
const index_t& KStride = g_k_c_wis_strides[I1];
const index_t& CStride = g_k_c_wis_strides[I2];
const index_t& ZStride = g_k_c_wis_strides[I3];
const index_t& YStride = g_k_c_wis_strides[I4];
const index_t& XStride = g_k_c_wis_strides[I5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(G, K, C, Z, Y, X),
make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto
MakeGKYXCTransposeDesc(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
const index_t& G = g_k_c_wis_lengths[I0];
const index_t& K = g_k_c_wis_lengths[I1];
const index_t& C = g_k_c_wis_lengths[I2];
const index_t& Z = g_k_c_wis_lengths[I3];
const index_t& Y = g_k_c_wis_lengths[I4];
const index_t& X = g_k_c_wis_lengths[I5];
const index_t& GStride = g_k_c_wis_strides[I0];
const index_t KStride = g_k_c_wis_strides[I1];
const index_t CStride = 1;
const index_t ZStride = Y * X * C;
const index_t YStride = X * C;
const index_t XStride = C;
const auto desc = make_naive_tensor_descriptor(
make_tuple(G, K, C, Z, Y, X),
make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
static auto TransposeInOutStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
{
if constexpr(device::is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
device::is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
const auto G = g_n_c_wis_lengths[I0];
@@ -236,6 +416,41 @@ struct TransformConvNGCHWToNHWGC
return g_n_c_wis_strides;
}
}
static auto
TransposeWeiStrides(const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
{
if constexpr(device::is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
device::is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
std::array<index_t, NDimSpatial + 3> g_k_c_wis_strides_transposed = g_k_c_wis_strides;
const index_t C = g_k_c_wis_lengths[I2];
if constexpr(NDimSpatial == 2)
{
const index_t X = g_k_c_wis_lengths[I4];
g_k_c_wis_strides_transposed[I2] = 1;
g_k_c_wis_strides_transposed[I3] = X * C;
g_k_c_wis_strides_transposed[I4] = C;
}
else if constexpr(NDimSpatial == 3)
{
const index_t Y = g_k_c_wis_lengths[I4];
const index_t X = g_k_c_wis_lengths[I5];
g_k_c_wis_strides_transposed[I2] = 1;
g_k_c_wis_strides_transposed[I3] = Y * X * C;
g_k_c_wis_strides_transposed[I4] = X * C;
g_k_c_wis_strides_transposed[I5] = C;
}
return g_k_c_wis_strides_transposed;
}
else
{
// transpose not needed
return g_k_c_wis_strides;
}
}
};
} // namespace tensor_operation

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -71,6 +71,10 @@ using GKXC = ck::tensor_layout::convolution::GKXC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GKCX = ck::tensor_layout::convolution::GKCX;
using GKCYX = ck::tensor_layout::convolution::GKCYX;
using GKCZYX = ck::tensor_layout::convolution::GKCZYX;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -272,20 +272,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
// layout NGCHW/GKYXC/NGKHW
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
is_same_v<WeiLayout, GKCYX> && is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances(
op_ptrs);
}
#endif
@@ -294,13 +294,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances(
op_ptrs);
}
#endif
@@ -311,14 +311,46 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances(
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances(
op_ptrs);
}
#endif
}
// layout NGCHW/GKYXC/NGKHW
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
@@ -326,14 +358,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances(
op_ptrs);
}
#endif
}

View File

@@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances(
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -91,10 +91,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -107,10 +107,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -122,22 +122,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(

View File

@@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -91,10 +91,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -107,10 +107,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -122,22 +122,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(

View File

@@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -91,10 +91,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -107,10 +107,10 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -122,22 +122,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(

View File

@@ -252,6 +252,55 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances(
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKCYX/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(

View File

@@ -24,10 +24,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -54,10 +54,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -84,10 +84,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -114,10 +114,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
int8_t,

View File

@@ -15,6 +15,10 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
# large tensor
# NHWGC, GKYXC, NHWGK
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
@@ -27,11 +31,10 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp
#mem
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
@@ -43,27 +46,24 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instance.cpp
#comp
# NHWGC, GKYXC, NHWGK
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp
#dl
# GNHWC, GKYXC, GNHWK
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
@@ -10,10 +10,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -39,7 +39,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
@@ -10,10 +10,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -39,7 +39,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
instances,
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});

View File

@@ -1,64 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_int8_comp_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
if(ck::get_device_name() != "gfx950")
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
if(ck::get_device_name() == "gfx950")
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_instances<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_instances<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault,

View File

@@ -1,39 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_int8_mem_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,39 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_int8_mem_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
BF16,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_inst
instances,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_inst
instances,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwd3x3>{});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F16,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_insta
instances,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_insta
instances,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwd3x3>{});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
@@ -9,10 +9,10 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
F32,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
instances,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
instances,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
NGCHW,
GKYXC,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwd3x3>{});

View File

@@ -1,48 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_int8_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_int8_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwd3x3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -16,6 +16,7 @@ enum struct ConvLayout
GNHWC_GKYXC_GNHWK, // 0
NHWGC_GKYXC_NHWGK, // 1
NGCHW_GKYXC_NGKHW, // 2
NGCHW_GKCYX_NGKHW, // 3
};
enum struct ConvDataType
@@ -52,11 +53,13 @@ static void print_helper_msg()
<< " 5: Input bf8, Weight bf8, Output fp8\n"
<< " 6: Input fp8, Weight bf8, Output fp8\n"
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
"G, K, Ho, Wo]\n"
"G, K, Ho, Wo]\n"
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
"G, K, Ho, Wo])\n"
<< "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg5: verification (0: no, 1: yes)\n"
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
@@ -110,6 +113,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
// using GKCX = ck::tensor_layout::convolution::GKXC;
using GKCYX = ck::tensor_layout::convolution::GKCYX;
// using GKCZYX = ck::tensor_layout::convolution::GKZYXC;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
@@ -302,6 +309,25 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{

View File

@@ -19,17 +19,20 @@ def init_const_args(args):
def run_ck_profiler_cmd(cmd):
print("ckProfiler command:")
print(cmd)
cmd_concatenated_str = ""
for arg in cmd:
cmd_concatenated_str += arg + " "
print(cmd_concatenated_str)
subprocess.run(cmd)
def parse_layouts(args):
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
args.in_layout == "NCDHW":
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 3
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
args.ck_profier_op == "grouped_conv_fwd":
args.layout = 3
elif args.ck_profier_op == "grouped_conv_bwd_data":
args.layout = 2
else:
print('Not supported layout for this op')

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -65,7 +65,10 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std::tuple<float, NGCHW, GKYXC, NGKHW>,
std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>,
std::tuple<ck::bhalf_t, NGCHW, GKYXC, NGKHW>,
std::tuple<int8_t, NGCHW, GKYXC, NGKHW>>;
std::tuple<int8_t, NGCHW, GKYXC, NGKHW>,
std::tuple<float, NGCHW, GKCYX, NGKHW>,
std::tuple<ck::half_t, NGCHW, GKCYX, NGKHW>,
std::tuple<ck::bhalf_t, NGCHW, GKCYX, NGKHW>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,