mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 08:18:26 +00:00
WMMA grouped conv fwd large tensor bias bnorm clamp (#3595)
* Added bias_bnorm_clamp for WMMA conv fwd large tensor. Following operations are added for FP16/BF16 data type and NHWGCxGKYXC layout. - grouped_conv2d_fwd_bias_bnorm_clamp - grouped_conv3d_fwd_bias_bnorm_clamp * changed strategy to handle GemmArgs array * Adding generic instance * fixed last nits from reviewers and copilot
This commit is contained in:
committed by
GitHub
parent
81ee19bd2c
commit
2e08a7e5ab
@@ -297,6 +297,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
|||||||
{
|
{
|
||||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
|
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||||
op_ptrs);
|
op_ptrs);
|
||||||
|
|
||||||
|
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||||
|
op_ptrs);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef CK_ENABLE_FP16
|
#ifdef CK_ENABLE_FP16
|
||||||
@@ -306,6 +309,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
|||||||
{
|
{
|
||||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
|
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||||
op_ptrs);
|
op_ptrs);
|
||||||
|
|
||||||
|
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||||
|
op_ptrs);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@@ -322,6 +328,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
|||||||
{
|
{
|
||||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||||
op_ptrs);
|
op_ptrs);
|
||||||
|
|
||||||
|
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||||
|
op_ptrs);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef CK_ENABLE_FP16
|
#ifdef CK_ENABLE_FP16
|
||||||
@@ -331,6 +340,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
|||||||
{
|
{
|
||||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||||
op_ptrs);
|
op_ptrs);
|
||||||
|
|
||||||
|
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||||
|
op_ptrs);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw
|
|||||||
PassThrough,
|
PassThrough,
|
||||||
BiasNormalizeInInferClamp>>>& instances);
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
|
|
||||||
|
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||||
|
std::vector<
|
||||||
|
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||||
|
NHWGC,
|
||||||
|
GKYXC,
|
||||||
|
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||||
|
NHWGK,
|
||||||
|
BF16,
|
||||||
|
BF16,
|
||||||
|
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||||
|
BF16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
|
|
||||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||||
std::vector<std::unique_ptr<
|
std::vector<std::unique_ptr<
|
||||||
DeviceGroupedConvFwdMultipleABD<3,
|
DeviceGroupedConvFwdMultipleABD<3,
|
||||||
@@ -38,6 +53,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_n
|
|||||||
PassThrough,
|
PassThrough,
|
||||||
PassThrough,
|
PassThrough,
|
||||||
BiasNormalizeInInferClamp>>>& instances);
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
|
|
||||||
|
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||||
|
std::vector<std::unique_ptr<
|
||||||
|
DeviceGroupedConvFwdMultipleABD<3,
|
||||||
|
NDHWGC,
|
||||||
|
GKZYXC,
|
||||||
|
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||||
|
NDHWGK,
|
||||||
|
BF16,
|
||||||
|
BF16,
|
||||||
|
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||||
|
BF16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef CK_ENABLE_FP16
|
#ifdef CK_ENABLE_FP16
|
||||||
@@ -56,6 +86,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw
|
|||||||
PassThrough,
|
PassThrough,
|
||||||
BiasNormalizeInInferClamp>>>& instances);
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
|
|
||||||
|
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||||
|
std::vector<
|
||||||
|
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||||
|
NHWGC,
|
||||||
|
GKYXC,
|
||||||
|
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||||
|
NHWGK,
|
||||||
|
F16,
|
||||||
|
F16,
|
||||||
|
Tuple<F16, F16, F16, F16, F16>,
|
||||||
|
F16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
|
|
||||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||||
std::vector<std::unique_ptr<
|
std::vector<std::unique_ptr<
|
||||||
DeviceGroupedConvFwdMultipleABD<3,
|
DeviceGroupedConvFwdMultipleABD<3,
|
||||||
@@ -70,6 +115,21 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_n
|
|||||||
PassThrough,
|
PassThrough,
|
||||||
PassThrough,
|
PassThrough,
|
||||||
BiasNormalizeInInferClamp>>>& instances);
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
|
|
||||||
|
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||||
|
std::vector<std::unique_ptr<
|
||||||
|
DeviceGroupedConvFwdMultipleABD<3,
|
||||||
|
NDHWGC,
|
||||||
|
GKZYXC,
|
||||||
|
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||||
|
NDHWGK,
|
||||||
|
F16,
|
||||||
|
F16,
|
||||||
|
Tuple<F16, F16, F16, F16, F16>,
|
||||||
|
F16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
} // namespace instance
|
} // namespace instance
|
||||||
|
|||||||
@@ -328,6 +328,8 @@ generate_sharded_instantiations(
|
|||||||
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
|
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
|
||||||
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||||
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||||
|
wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||||
|
wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||||
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
|
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||||
|
|
||||||
|
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||||
|
std::vector<
|
||||||
|
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||||
|
NHWGC,
|
||||||
|
GKYXC,
|
||||||
|
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||||
|
NHWGK,
|
||||||
|
BF16,
|
||||||
|
BF16,
|
||||||
|
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||||
|
BF16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<
|
||||||
|
2,
|
||||||
|
NHWGC,
|
||||||
|
GKYXC,
|
||||||
|
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||||
|
NHWGK,
|
||||||
|
ConvFwdDefault,
|
||||||
|
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||||
|
BiasNormalizeInInferClamp>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||||
|
|
||||||
|
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||||
|
std::vector<
|
||||||
|
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||||
|
NHWGC,
|
||||||
|
GKYXC,
|
||||||
|
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||||
|
NHWGK,
|
||||||
|
F16,
|
||||||
|
F16,
|
||||||
|
Tuple<F16, F16, F16, F16, F16>,
|
||||||
|
F16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<
|
||||||
|
2,
|
||||||
|
NHWGC,
|
||||||
|
GKYXC,
|
||||||
|
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||||
|
NHWGK,
|
||||||
|
ConvFwdDefault,
|
||||||
|
Tuple<F16, F16, F16, F16, F16>,
|
||||||
|
BiasNormalizeInInferClamp>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
@@ -309,6 +309,8 @@ generate_sharded_instantiations(
|
|||||||
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
|
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
|
||||||
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||||
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||||
|
wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||||
|
wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||||
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
|
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||||
|
|
||||||
|
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||||
|
std::vector<std::unique_ptr<
|
||||||
|
DeviceGroupedConvFwdMultipleABD<3,
|
||||||
|
NDHWGC,
|
||||||
|
GKZYXC,
|
||||||
|
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||||
|
NDHWGK,
|
||||||
|
BF16,
|
||||||
|
BF16,
|
||||||
|
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||||
|
BF16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<
|
||||||
|
3,
|
||||||
|
NDHWGC,
|
||||||
|
GKZYXC,
|
||||||
|
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||||
|
NDHWGK,
|
||||||
|
ConvFwdDefault,
|
||||||
|
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||||
|
BiasNormalizeInInferClamp>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||||
|
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
namespace instance {
|
||||||
|
|
||||||
|
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||||
|
|
||||||
|
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||||
|
std::vector<std::unique_ptr<
|
||||||
|
DeviceGroupedConvFwdMultipleABD<3,
|
||||||
|
NDHWGC,
|
||||||
|
GKZYXC,
|
||||||
|
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||||
|
NDHWGK,
|
||||||
|
F16,
|
||||||
|
F16,
|
||||||
|
Tuple<F16, F16, F16, F16, F16>,
|
||||||
|
F16,
|
||||||
|
PassThrough,
|
||||||
|
PassThrough,
|
||||||
|
BiasNormalizeInInferClamp>>>& instances)
|
||||||
|
{
|
||||||
|
add_device_operation_instances(instances,
|
||||||
|
device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<
|
||||||
|
3,
|
||||||
|
NDHWGC,
|
||||||
|
GKZYXC,
|
||||||
|
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||||
|
NDHWGK,
|
||||||
|
ConvFwdDefault,
|
||||||
|
Tuple<F16, F16, F16, F16, F16>,
|
||||||
|
BiasNormalizeInInferClamp>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace instance
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
Reference in New Issue
Block a user