mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Add support for GKCYX grouped conv weight (#2023)
* Grouped conv bwd weight GKCYX support
* fix and changelog
* fix
* fix
* fixes
* comments
* fix
[ROCm/composable_kernel commit: 2ccf914888]
This commit is contained in:
@@ -8,8 +8,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
|
||||
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
|
||||
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
|
||||
* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
|
||||
* Added support for Stream-K version of mixed fp8/bf16 GEMM
|
||||
|
||||
### Optimized
|
||||
@@ -26,6 +26,7 @@ None
|
||||
* Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876)
|
||||
* DL and DPP kernels are now enabled by default.
|
||||
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
|
||||
|
||||
### Known issues
|
||||
|
||||
@@ -30,14 +30,14 @@ List of the device operations for grouped convolution forward in CK:
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16 |2D, 3D|2D|1D, 2D, 3D|
|
||||
|fp16 |2D, 3D|2D|1D, 2D, 3D|
|
||||
|fp32 |2D, 3D|2D|1D, 2D, 3D|
|
||||
|int8 |2D, 3D|2D|1D, 3D|
|
||||
|fp8 |3D|✗|✗|
|
||||
|bf8 |3D|✗|✗|
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|---|
|
||||
|bf16 |2D, 3D|2D|2D|1D, 2D, 3D|
|
||||
|fp16 |2D, 3D|2D|2D|1D, 2D, 3D|
|
||||
|fp32 |2D, 3D|2D|2D|1D, 2D, 3D|
|
||||
|int8 |2D, 3D|2D|2D|1D, 3D|
|
||||
|fp8 |3D|✗|✗|✗|
|
||||
|bf8 |3D|✗|✗|✗|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction:
|
||||
|
||||
|
||||
@@ -34,12 +34,12 @@ List of the device operations for grouped convolution backward weight in CK:
|
||||
|
||||
Table of supported cases by instance factory with XDL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16|2D, 3D|2D, 3D|✗|
|
||||
|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D|
|
||||
|fp16 |2D, 3D|2D, 3D|1D, 2D, 3D|
|
||||
|fp32 |2D, 3D|2D, 3D|1D, 2D, 3D|
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|---|
|
||||
|bf16|2D, 3D|2D, 3D|2D, 3D|✗|
|
||||
|bf16(fp32 for weight)|2D, 3D|✗|✗|1D, 2D, 3D|
|
||||
|fp16 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D|
|
||||
|fp32 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction:
|
||||
|
||||
|
||||
@@ -218,8 +218,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using EDataType = WeiDataType;
|
||||
|
||||
// If NGCHW then ADataType must be equal to BDataType
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
is_same_v<ADataType, BDataType>);
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
@@ -376,6 +376,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKCYXTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKYXCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
@@ -452,6 +458,28 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
I1,
|
||||
I1>;
|
||||
// NPerBlock is used for the first dim which is store dimension
|
||||
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
|
||||
// CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so
|
||||
// it is more flexible to use this dim for store dimension with such scalar
|
||||
// per vector.
|
||||
using GridwiseElementwiseWeightTransposeCast =
|
||||
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
|
||||
Tuple<GKCYXTransposeDescType>,
|
||||
Tuple<const AccDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<0, 1>,
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
Sequence<1>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
@@ -533,12 +561,15 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
@@ -550,7 +581,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
@@ -580,29 +611,21 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_right_pads,
|
||||
k_batch_)[I2];
|
||||
|
||||
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
|
||||
ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
|
||||
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ce_grid_desc_m_n_,
|
||||
GridwiseGemm::CalculateMBlock(GemmM),
|
||||
GridwiseGemm::CalculateNBlock(GemmN));
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
@@ -618,17 +641,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
|
||||
elementwise_block_2_ctile_map_ =
|
||||
is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>()
|
||||
? Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0),
|
||||
e_in_transpose_desc_.GetLength(I1)}
|
||||
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
|
||||
ce_grid_desc_m_n_.GetLength(I1)};
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(
|
||||
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
|
||||
128;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
@@ -638,14 +679,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(sizeof(AccDataType) *
|
||||
ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_,
|
||||
128) *
|
||||
128;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
// 1. We need to transpose A and B for NGCHW and NGKHW layouts
|
||||
// 2. If C format is GKCYX then tranpose during second stage.
|
||||
// If C format is GKYXC then just perform second stage.
|
||||
// Due to the fact that E workspace is always needed, we
|
||||
// allocate them as the first part of the workspace.
|
||||
// [EWorkspace, AWorkspace, BWorkspace]
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
@@ -672,6 +722,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
GKYXCTransposeDescType e_in_transpose_desc_;
|
||||
GKCYXTransposeDescType e_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
@@ -728,11 +780,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
|
||||
p_b_grid =
|
||||
type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
@@ -1373,41 +1425,72 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
float avg_time = 0.f;
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
arg.ce_elementwise_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
|
||||
std::array<index_t, I1> in_out_batch_strides = {
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
|
||||
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
I1,
|
||||
I1>;
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_,
|
||||
arg.Conv_G_,
|
||||
in_out_batch_strides,
|
||||
in_out_batch_strides);
|
||||
const auto kernel = kernel_elementwise<GridwiseElementwiseWeightTransposeCast,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
|
||||
arg.ce_elementwise_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<CElementwiseGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_,
|
||||
arg.Conv_G_,
|
||||
in_out_batch_strides,
|
||||
in_out_batch_strides);
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_a =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -1417,7 +1500,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.b_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
|
||||
arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType);
|
||||
BDataType* p_b_out_grid =
|
||||
type_convert<BDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
@@ -1514,7 +1597,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1522,7 +1605,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1597,8 +1680,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
|
||||
{
|
||||
@@ -1767,8 +1850,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< NumGroupsToMerge;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
str << ", TransposeTransferSrcScalarPerVector: "
|
||||
<< TransposeTransferSrcScalarPerVector <<", "
|
||||
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
|
||||
|
||||
@@ -165,8 +165,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
using CDataType = WeiDataType;
|
||||
|
||||
// If NGCHW then ADataType must be equal to BDataType
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
||||
is_same_v<ADataType, BDataType>);
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
@@ -301,7 +301,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
static constexpr index_t TransposeTransferSrcScalarPerVectorAligned =
|
||||
std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector);
|
||||
@@ -314,13 +314,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKCYXTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using GKYXCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
using GridwiseInOutTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -333,6 +339,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
// NPerBlock is used for the first dim which is store dimension
|
||||
// (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
|
||||
using GridwiseElementwiseWeightTranspose =
|
||||
GridwiseElementwise<Tuple<GKYXCTransposeDescType>,
|
||||
Tuple<GKCYXTransposeDescType>,
|
||||
Tuple<const CDataType*>,
|
||||
Tuple<CDataType*>,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
Sequence<1>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -452,13 +478,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -469,7 +497,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
@@ -487,12 +515,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
@@ -503,8 +526,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
@@ -520,31 +543,33 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
|
||||
e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
|
||||
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{
|
||||
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes();
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(
|
||||
sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
|
||||
128;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -552,6 +577,41 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
// Align to 128B
|
||||
return math::integer_divide_ceil(
|
||||
sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) *
|
||||
128;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
|
||||
GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
@@ -562,12 +622,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_;
|
||||
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
|
||||
GKYXCTransposeDescType e_in_transpose_desc_;
|
||||
GKCYXTransposeDescType e_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
|
||||
@@ -621,9 +684,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
CDataType* p_e_grid = arg.p_c_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
p_e_grid =
|
||||
type_convert<CDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(CDataType);
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_a =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -640,8 +713,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
GridwiseElementwiseTranspose,
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseInOutTranspose,
|
||||
GridwiseInOutTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
@@ -650,8 +723,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapTranspose,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
@@ -698,24 +771,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
has_main_loop>;
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
const auto clear_workspace = [&]() {
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(p_e_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
@@ -726,6 +811,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
const index_t grid_size_e =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const CDataType* p_e_in_grid = static_cast<const CDataType*>(p_e_grid);
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseWeightTranspose,
|
||||
ck::Tuple<GKYXCTransposeDescType>,
|
||||
ck::Tuple<GKCYXTransposeDescType>,
|
||||
ck::Tuple<const CDataType*>,
|
||||
ck::Tuple<CDataType*>,
|
||||
Block2TileMapTranspose,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size_e),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(arg.p_c_grid_),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
@@ -763,7 +880,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -772,7 +889,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -810,8 +927,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0)
|
||||
{
|
||||
@@ -980,8 +1097,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
if constexpr(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
|
||||
str << ", TransposeTransferSrcScalarPerVectorAligned: "
|
||||
<< TransposeTransferSrcScalarPerVectorAligned <<", "
|
||||
<< "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned;
|
||||
|
||||
@@ -502,6 +502,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
// NPerBlock is used for the first and second dim which to use
|
||||
// CDEBlockTransferScalarPerVector_NPerBlock for load and store during
|
||||
// transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to
|
||||
// NPerBlock so it is more flexible to use this dim for load store dimension
|
||||
// with such scalar per vector.
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
|
||||
@@ -12,6 +12,15 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
/*
|
||||
* Transform Convolution NGCHW to NHWGC. We transform [N, G, C, H, W] tensor
|
||||
* descriptor to [N * G * C, H * W] (input or output image). The first
|
||||
* dimension is store dimension, the second one is load dimension. For
|
||||
* NHWGC to NGCHW load and store are reverted. For weight we transform
|
||||
* [G, K, C, Y, X] to [G * K * Y * X, C]. First dim is load dimension,
|
||||
* second dim is store dimension.
|
||||
*/
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
|
||||
@@ -431,6 +431,51 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKCYX> &&
|
||||
is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
|
||||
is_same_v<ComputeTypeB, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
|
||||
@@ -443,12 +488,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -460,12 +499,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -610,6 +643,51 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKCZYX> &&
|
||||
is_same_v<OutLayout, NGKDHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
|
||||
is_same_v<ComputeTypeB, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
@@ -622,12 +700,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -639,12 +711,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
|
||||
@@ -149,10 +149,10 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -292,10 +292,10 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -304,10 +304,22 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -329,10 +341,10 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -461,10 +473,10 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -473,10 +485,22 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -510,6 +534,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -611,10 +647,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -755,10 +791,10 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -767,10 +803,22 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -792,10 +840,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -924,10 +972,10 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -936,10 +984,22 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -961,6 +1021,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
|
||||
@@ -1,47 +1,53 @@
|
||||
# ONLY XDL_AND_DL_KERNELS
|
||||
set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp
|
||||
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp
|
||||
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp
|
||||
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp
|
||||
xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp
|
||||
|
||||
xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
|
||||
xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_p
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pi
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
1,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
4,
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
1,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
4,
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
1,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
4,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,43 +1,49 @@
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp
|
||||
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
|
||||
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp
|
||||
|
||||
xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
|
||||
xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp
|
||||
xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp
|
||||
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp
|
||||
xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
@@ -62,7 +68,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf1
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
1,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
4,
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
1,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
4,
|
||||
@@ -10,10 +10,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
F32,
|
||||
F32,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
1,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
GKCZYX,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
4,
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
@@ -17,6 +17,7 @@ enum struct ConvLayout
|
||||
GNHWC_GKYXC_GNHWK, // 1
|
||||
NHWGC_GKYXC_NHWGK, // 2
|
||||
NGCHW_GKYXC_NGKHW, // 3
|
||||
NGCHW_GKCYX_NGKHW, // 4
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -49,6 +50,8 @@ static void print_helper_msg()
|
||||
"Ho, Wo, G, K]\n"
|
||||
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< " 4: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -199,6 +202,21 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
@@ -262,6 +280,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(
|
||||
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
|
||||
@@ -29,8 +29,9 @@ def run_ck_profiler_cmd(cmd):
|
||||
def parse_layouts(args):
|
||||
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
|
||||
args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd" or \
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 4
|
||||
elif args.ck_profier_op == "grouped_conv_fwd" or \
|
||||
args.ck_profier_op == "grouped_conv_bwd_data":
|
||||
args.layout = 3
|
||||
else:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user