Add support for GKCYX grouped conv weight (#2023)

* Grouped conv bwd weight GKCYX support

* fix and changelog

* fix

* fix

* fixes

* comments

* fix

[ROCm/composable_kernel commit: 2ccf914888]
This commit is contained in:
Bartłomiej Kocot
2025-04-02 23:59:49 +02:00
committed by GitHub
parent 0a607256cf
commit 49565538fe
101 changed files with 1004 additions and 356 deletions

View File

@@ -8,8 +8,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
* Added support for Stream-K version of mixed fp8/bf16 GEMM
### Optimized
@@ -26,6 +26,7 @@ None
* Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876)
* DL and DPP kernels are now enabled by default.
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
### Known issues

View File

@@ -30,14 +30,14 @@ List of the device operations for grouped convolution forward in CK:
Table of supported cases by instance factory with XDL instruction:
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|-------|---|---|---|
|bf16 |2D, 3D|2D|1D, 2D, 3D|
|fp16 |2D, 3D|2D|1D, 2D, 3D|
|fp32 |2D, 3D|2D|1D, 2D, 3D|
|int8 |2D, 3D|2D|1D, 3D|
|fp8 |3D|✗|✗|
|bf8 |3D|✗|✗|
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK|
|-------|---|---|---|---|
|bf16 |2D, 3D|2D|2D|1D, 2D, 3D|
|fp16 |2D, 3D|2D|2D|1D, 2D, 3D|
|fp32 |2D, 3D|2D|2D|1D, 2D, 3D|
|int8 |2D, 3D|2D|2D|1D, 3D|
|fp8 |3D|✗|✗|✗|
|bf8 |3D|✗|✗|✗|
Table of supported cases by instance factory with WMMA instruction:

View File

@@ -34,12 +34,12 @@ List of the device operations for grouped convolution backward weight in CK:
Table of supported cases by instance factory with XDL instruction:
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|-------|---|---|---|
|bf16|2D, 3D|2D, 3D|✗|
|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D|
|fp16 |2D, 3D|2D, 3D|1D, 2D, 3D|
|fp32 |2D, 3D|2D, 3D|1D, 2D, 3D|
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK|
|-------|---|---|---|---|
|bf16|2D, 3D|2D, 3D|2D, 3D|✗|
|bf16(fp32 for weight)|2D, 3D|✗|✗|1D, 2D, 3D|
|fp16 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D|
|fp32 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D|
Table of supported cases by instance factory with WMMA instruction:

View File

@@ -218,8 +218,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using EDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation;
@@ -376,6 +376,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
@@ -452,6 +458,28 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
I1,
I1>;
// NPerBlock is used for the first dim which is store dimension
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
// CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so
// it is more flexible to use this dim for store dimension with such scalar
// per vector.
using GridwiseElementwiseWeightTransposeCast =
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
Tuple<GKCYXTransposeDescType>,
Tuple<const AccDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<0, 1>,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
Sequence<1>,
I1,
I0>;
using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
@@ -533,12 +561,15 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
const auto descs =
conv_to_gemm_transformer_v2
@@ -550,7 +581,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
@@ -580,29 +611,21 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ce_grid_desc_m_n_,
GridwiseGemm::CalculateMBlock(GemmM),
GridwiseGemm::CalculateNBlock(GemmN));
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
@@ -618,17 +641,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
}
elementwise_block_2_ctile_map_ =
is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>()
? Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0),
e_in_transpose_desc_.GetLength(I1)}
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
ce_grid_desc_m_n_.GetLength(I1)};
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
// Align to 128B
return math::integer_divide_ceil(
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
128;
}
std::size_t GetWorkspaceBTensorSizeBytes() const
@@ -638,14 +679,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
// Align to 128B
return math::integer_divide_ceil(sizeof(AccDataType) *
ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_,
128) *
128;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
// 1. We need to transpose A and B for NGCHW and NGKHW layouts
// 2. If C format is GKCYX then tranpose during second stage.
// If C format is GKYXC then just perform second stage.
// Due to the fact that E workspace is always needed, we
// allocate them as the first part of the workspace.
// [EWorkspace, AWorkspace, BWorkspace]
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
@@ -672,6 +722,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
GKYXCTransposeDescType e_in_transpose_desc_;
GKCYXTransposeDescType e_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
@@ -728,11 +780,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
p_b_grid =
type_convert<const BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
@@ -1373,41 +1425,72 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
float avg_time = 0.f;
auto launch_elementwise_kernel = [&]() {
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.ce_elementwise_grid_desc_m_n_) *
arg.Conv_G_;
std::array<index_t, I1> in_out_batch_strides = {
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation,
I1,
I1>;
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.e_in_transpose_desc_);
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
arg.cde_element_op_,
arg.Conv_G_,
in_out_batch_strides,
in_out_batch_strides);
const auto kernel = kernel_elementwise<GridwiseElementwiseWeightTransposeCast,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
arg.cde_element_op_);
}
else
{
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.ce_elementwise_grid_desc_m_n_) *
arg.Conv_G_;
const auto kernel =
kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
CDEElementwiseOperation,
I1,
I1>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
arg.cde_element_op_,
arg.Conv_G_,
in_out_batch_strides,
in_out_batch_strides);
}
};
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -1417,7 +1500,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
BDataType* p_b_out_grid =
type_convert<BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
@@ -1514,7 +1597,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -1522,7 +1605,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -1597,8 +1680,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
{
@@ -1767,8 +1850,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVector: "
<< TransposeTransferSrcScalarPerVector <<", "
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;

View File

@@ -165,8 +165,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using CDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation;
@@ -301,7 +301,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
static constexpr index_t TransposeTransferSrcScalarPerVectorAligned =
std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector);
@@ -314,13 +314,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
using GridwiseElementwiseTranspose =
using GridwiseInOutTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapTranspose,
element_wise::PassThrough,
BlockSize,
MPerBlock,
@@ -333,6 +339,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
I1,
I0>;
// NPerBlock is used for the first dim which is store dimension
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
Tuple<GKCYXTransposeDescType>,
Tuple<const CDataType*>,
Tuple<CDataType*>,
Block2TileMapTranspose,
element_wise::PassThrough,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
Sequence<1>,
I1,
I0>;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType,
@@ -452,13 +478,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -469,7 +497,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
@@ -487,12 +515,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
@@ -503,8 +526,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
@@ -520,31 +543,33 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes();
// Align to 128B
return math::integer_divide_ceil(
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
128;
}
else
{
@@ -552,6 +577,41 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
// Align to 128B
return math::integer_divide_ceil(
sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) *
128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize();
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
@@ -562,12 +622,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_;
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
GKYXCTransposeDescType e_in_transpose_desc_;
GKCYXTransposeDescType e_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
@@ -621,9 +684,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
CDataType* p_e_grid = arg.p_c_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_e_grid =
type_convert<CDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(CDataType);
}
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -640,8 +713,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
GridwiseElementwiseTranspose,
auto kernel_transpose = kernel_elementwise_dual<GridwiseInOutTranspose,
GridwiseInOutTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
@@ -650,8 +723,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
Block2TileMapTranspose,
Block2TileMapTranspose,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
@@ -698,24 +771,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch<>,
has_main_loop>;
avg_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
const auto clear_workspace = [&]() {
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
hip_check_error(hipMemsetAsync(p_e_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
}
};
avg_time += launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
if(has_main_k0_block_loop)
@@ -726,6 +811,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
launch_kernel(integral_constant<bool, false>{});
}
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_e =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const CDataType* p_e_in_grid = static_cast<const CDataType*>(p_e_grid);
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseWeightTranspose,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<const CDataType*>,
ck::Tuple<CDataType*>,
Block2TileMapTranspose,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size_e),
dim3(BlockSize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_in_grid),
make_tuple(arg.p_c_grid_),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
return avg_time;
}
@@ -763,7 +880,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -772,7 +889,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
@@ -810,8 +927,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0)
{
@@ -980,8 +1097,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVectorAligned: "
<< TransposeTransferSrcScalarPerVectorAligned <<", "
<< "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned;

View File

@@ -502,6 +502,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
// NPerBlock is used for the first and second dim which to use
// CDEBlockTransferScalarPerVector_NPerBlock for load and store during
// transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to
// NPerBlock so it is more flexible to use this dim for load store dimension
// with such scalar per vector.
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,

View File

@@ -12,6 +12,15 @@
namespace ck {
namespace tensor_operation {
/*
* Transform Convolution NGCHW to NHWGC. We transform [N, G, C, H, W] tensor
* descriptor to [N * G * C, H * W] (input or output image). The first
* dimension is store dimension, the second one is load dimension. For
* NHWGC to NGCHW load and store are reverted. For weight we transform
* [G, K, C, Y, X] to [G * K * Y * X, C]. First dim is load dimension,
* second dim is store dimension.
*/
template <typename ALayout,
typename BLayout,
typename ELayout,

View File

@@ -431,6 +431,51 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKCYX> &&
is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
@@ -443,12 +488,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -460,12 +499,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
@@ -610,6 +643,51 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKCZYX> &&
is_same_v<OutLayout, NGKDHW>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
@@ -622,12 +700,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -639,12 +711,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32

View File

@@ -149,10 +149,10 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
BF16,
BF16,
@@ -292,10 +292,10 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
BF16,
BF16,
@@ -304,10 +304,22 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKCYX,
NGKHW,
BF16,
BF16,
@@ -329,10 +341,10 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F16,
F16,
@@ -461,10 +473,10 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F16,
F16,
@@ -473,10 +485,22 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKCYX,
NGKHW,
F16,
F16,
@@ -510,6 +534,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKCYX,
NGKHW,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -611,10 +647,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
BF16,
BF16,
@@ -755,10 +791,10 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
BF16,
BF16,
@@ -767,10 +803,22 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKCZYX,
NGKDHW,
BF16,
BF16,
@@ -792,10 +840,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F16,
F16,
@@ -924,10 +972,10 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F16,
F16,
@@ -936,10 +984,22 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKCZYX,
NGKDHW,
F16,
F16,
@@ -961,6 +1021,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKCZYX,
NGKDHW,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,

View File

@@ -1,47 +1,53 @@
# ONLY XDL_AND_DL_KERNELS
set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp
xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
)
if(DL_KERNELS)

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKCYX,
NGKHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
2,
NGCHW,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
BF16,
BF16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
BF16,
BF16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKCYX,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
2,
NGCHW,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F16,
F16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F16,
F16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
BF16,
BF16,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
1,
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
4,

View File

@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F16,
F16,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
1,
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
4,

View File

@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
F32,
F32,
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
1,
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
NGCHW,
GKYXC,
GKCYX,
NGKHW,
ConvBwdWeightDefault,
4,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<2,
NGCHW,
GKYXC,
NGKHW,
ConvBwdWeightDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,43 +1,49 @@
# XDL_DL_WMMA_KERNELS
set(GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp
xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp
xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp
)
if(DL_KERNELS)
@@ -62,7 +68,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
endif()
add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT})

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKCZYX,
NGKDHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
3,
NGCDHW,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
BF16,
BF16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
BF16,
BF16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKCZYX,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
3,
NGCDHW,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F16,
F16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F16,
F16,
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,

View File

@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
BF16,
BF16,
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
1,
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
4,

View File

@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F16,
F16,
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
1,
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
4,

View File

@@ -10,10 +10,10 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances(
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
F32,
F32,
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
1,
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3,
NGCDHW,
GKZYXC,
GKCZYX,
NGKDHW,
ConvBwdWeightDefault,
4,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<3,
NGCDHW,
GKZYXC,
NGKDHW,
ConvBwdWeightDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
@@ -17,6 +17,7 @@ enum struct ConvLayout
GNHWC_GKYXC_GNHWK, // 1
NHWGC_GKYXC_NHWGK, // 2
NGCHW_GKYXC_NGKHW, // 3
NGCHW_GKCYX_NGKHW, // 4
};
enum struct ConvDataType
@@ -49,6 +50,8 @@ static void print_helper_msg()
"Ho, Wo, G, K]\n"
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
"G, K, Ho, Wo]\n"
<< " 4: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
"G, K, Ho, Wo]\n"
<< "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
@@ -199,6 +202,21 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
@@ -262,6 +280,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;

View File

@@ -29,8 +29,9 @@ def run_ck_profiler_cmd(cmd):
def parse_layouts(args):
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
args.in_layout == "NCDHW":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
args.ck_profier_op == "grouped_conv_fwd" or \
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 4
elif args.ck_profier_op == "grouped_conv_fwd" or \
args.ck_profier_op == "grouped_conv_bwd_data":
args.layout = 3
else:

Some files were not shown because too many files have changed in this diff Show More