Grouped conv bias clamp fp32/fp16 support (#2366)

[ROCm/composable_kernel commit: 663992e99b]
This commit is contained in:
Bartłomiej Kocot
2025-06-20 11:41:04 +02:00
committed by GitHub
parent 7c10189a27
commit 9e27236fb7
71 changed files with 4733 additions and 22 deletions

View File

@@ -99,6 +99,52 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
@@ -127,6 +173,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_XDL

View File

@@ -236,6 +236,434 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -98,6 +98,50 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
@@ -126,6 +170,46 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_XDL

View File

@@ -236,6 +236,434 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation