mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Add support for NGCHW in grouped conv fwd (#1499)
* Support NGCHW in grouped conv fwd * Remove not needed variable * Fixes
This commit is contained in:
@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
|
||||
|
||||
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
|
||||
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
|
||||
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
|
||||
@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
|
||||
@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
Reference in New Issue
Block a user