diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 8090b2449a..01415c2ddd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -304,7 +304,23 @@ struct DeviceOperationInstanceFactory && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances( + op_ptrs); + } +#endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc index e47a876e1b..9a83e36b99 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -90,6 +90,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances( PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances( std::vector>>& instances); #endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances( std::vector>>& instances); #endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances( std::vector>>& instances); #endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 98bee66a91..146916cfd3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -11,6 +11,7 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp # NGCHW, GKYXC, NGKHW + xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp @@ -27,6 +28,7 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp # NGCHW, GKYXC, NGKHW + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp @@ -42,10 +44,12 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp # NGCHW, GKYXC, NGKHW + xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp # NGCHW, GKYXC, NGKHW + xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instance.cpp @@ -56,6 +60,7 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp # NGCHW, GKYXC, NGKHW + xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp new file mode 100644 index 0000000000..65e233ce08 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp new file mode 100644 index 0000000000..6ee6aa1e4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..88b5f30da5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..48cca9c3f3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp new file mode 100644 index 0000000000..14f00d8e88 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 1abd4fd9f6..25481e0d7f 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -64,6 +64,7 @@ using KernelTypes2d = ::testing::Types, std::tuple, std::tuple, std::tuple, + std::tuple, std::tuple>; using KernelTypes3d = ::testing::Types,