WMMA grouped conv fwd large tensor bias bnorm clamp (#3595)

* Added bias_bnorm_clamp for WMMA conv fwd large tensor.

Following operations are added for FP16/BF16 data type and NHWGCxGKYXC layout.
- grouped_conv2d_fwd_bias_bnorm_clamp
- grouped_conv3d_fwd_bias_bnorm_clamp

* changed strategy to handle GemmArgs array

* Adding generic instance

* fixed last nits from reviewers and copilot

[ROCm/composable_kernel commit: 2e08a7e5ab]
This commit is contained in:
Wojciech Laskowski
2026-01-23 12:20:00 +01:00
committed by GitHub
parent ee595ee58a
commit a0445aff5f
8 changed files with 252 additions and 0 deletions

View File

@@ -297,6 +297,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
@@ -306,6 +309,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
}
@@ -322,6 +328,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
@@ -331,6 +340,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
}

View File

@@ -24,6 +24,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
@@ -38,6 +53,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_n
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
@@ -56,6 +86,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
@@ -70,6 +115,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_n
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif
} // namespace instance