mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Adding remaining conv, dynamic_op, and scaleadd_scaleadd_relu flavors for grouped conv fwd (#3529)
* Adding remaining flavors for grouped conv fwd
As titled. Following variants are added:
- grouped_conv2d_fwd_dynamic_op
- grouped_conv3d_fwd_dynamic_op
- grouped_conv3d_fwd_bilinear
- grouped_conv3d_fwd_convscale
- grouped_conv3d_fwd_convinvscale
- grouped_conv3d_fwd_convscale_add
- grouped_conv3d_fwd_convscale_relu
- grouped_conv3d_fwd_scale
- grouped_conv3d_fwd_combconvscale
- grouped_conv3d_fwd_scaleadd_scaleadd_relu
* Fix incomplete parsing of types from source names in add_instance_library() cmakelists function so we don't build f8 on RDNA3.
* Do not build f8 / bf8 only flavor tests on RDNA3
* Make sure we have proper generic instances for all instance lists related to the post-ces extra flavors, with scalarPerVector = 1. Then disable all but one generic instance per instance list to reduce compile time.
* Post rebase fix: Template parameters for Grouped Conv Fwd Device Impl got tweaked upstream.
* adding int8 and fp16 overloads to the elementwise operations
* fixed copilot nits
* Addressing review comments:
- removed unnecessary examples for dynamic op
- removed unnecessary conv specializations for all the flavors
- removed spurious bilinear and scale source files
* clang-format
* reduced the number of tests
---------
Co-authored-by: Wojciech Laskowski <wojciech.laskowski@streamhpc.com>
[ROCm/composable_kernel commit: 2377a62837]
This commit is contained in:
committed by
GitHub
parent
09d443a7ad
commit
65c2e81817
@@ -104,7 +104,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Do not build WMMA grouped conv 3d fwd fp8 / bf8 for any targets except gfx12+
|
||||
if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "grouped_conv3d_fwd_wmma" AND (source_name MATCHES "_fp8_" OR source_name MATCHES "_bf8_"))
|
||||
if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "grouped_conv3d_fwd_wmma" AND source_name MATCHES "_(f8|fp8|bf8)_")
|
||||
message(DEBUG "removing grouped_conv3d_fwd_wmma fp8/bf8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV2D_FWD_DYNAMIC_OP
|
||||
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instance.cpp)
|
||||
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv2d_fwd_dynamic_op_instance ${GROUPED_CONV2D_FWD_DYNAMIC_OP})
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<2, NHWGC, GKYXC,
// Tuple<>, NHWGK, ConvFwdDefault> object to add_device_operation_instances().
// Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (layout tags, data types, element-wise ops) is assumed from the tag names; the
// class template is declared outside this file - confirm there.
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
ck::Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<2, NHWGC, GKYXC,
// Tuple<>, NHWGK, ConvFwdDefault> object to add_device_operation_instances().
// Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (layout tags, data types, element-wise ops) is assumed from the tag names; the
// class template is declared outside this file - confirm there.
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
ck::Tuple<>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_CONVINVSCALE
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_convinvscale_instance ${GROUPED_CONV3D_FWD_CONVINVSCALE})
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3, NDHWGC, GKZYXC,
// ck::Tuple<>, NDHWGK, ConvFwdDefault, ConvInvscale> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot,
// including the two trailing F8 arguments after ConvInvscale (presumably compute
// types), is assumed from the tag names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
ck::Tuple<>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvInvscale,
|
||||
F8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvInvscale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,12 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_CONVSCALE
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE})
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<3, NDHWGC,
// GKZYXC, ck::Tuple<>, NDHWGK, ConvFwdDefault, CombConvScale> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// F8, F32 and CombConvScale are not declared in this file; they come from the
// included headers.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (F8, F8 followed by F32, then two trailing F8 arguments) is assumed from the tag
// names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
ck::Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CombConvScale,
|
||||
F8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
CombConvScale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_f8_instances<3, NDHWGC,
// GKZYXC, ck::Tuple<>, NDHWGK, ConvFwdDefault, ConvScale> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (BF8, F8, F8 data types and the trailing BF8, F8 pair) is assumed from the tag
// names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
BF8,
|
||||
F8,
|
||||
ck::Tuple<>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvScale,
|
||||
BF8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvScale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_instances<3, NDHWGC,
// GKZYXC, ck::Tuple<>, NDHWGK, ConvFwdDefault, ConvScale> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): the third data-type argument here is F8 while the others are BF8;
// the role of each DeviceGroupedConvFwdMultipleABD template slot is assumed from the
// tag names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
BF8,
|
||||
BF8,
|
||||
ck::Tuple<>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvScale,
|
||||
BF8,
|
||||
BF8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_bf8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvScale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_bf8_instances<3, NDHWGC,
// GKZYXC, ck::Tuple<>, NDHWGK, ConvFwdDefault, ConvScale> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (F8, BF8, F8 data types and the trailing F8, BF8 pair) is assumed from the tag
// names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
BF8,
|
||||
ck::Tuple<>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvScale,
|
||||
F8,
|
||||
BF8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_bf8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvScale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3, NDHWGC, GKZYXC,
// ck::Tuple<>, NDHWGK, ConvFwdDefault, ConvScale> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot,
// including the two trailing F8 arguments after ConvScale, is assumed from the tag
// names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
ck::Tuple<>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvScale,
|
||||
F8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvScale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_CONVSCALE_ADD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_convscale_add_instance ${GROUPED_CONV3D_FWD_CONVSCALE_ADD})
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_instance.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd;
|
||||
|
||||
// Calls add_device_operation_instances() three times on `instances`, each time with
// a temporary device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
// NDHWGC, GKZYXC, ck::Tuple<NDHWGK>, NDHWGK, <spec>, ConvScaleAdd> object, where
// <spec> is ConvFwdDefault, ConvFwd1x1P0 and ConvFwd1x1S1P0 in that order.
// NOTE(review): the sibling WMMA instance files in this change pass only
// ConvFwdDefault; confirm that keeping the two 1x1 specializations here is intended.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (ck::Tuple<NDHWGK> / ck::Tuple<F32> presumably describe one extra D tensor) is
// assumed from the tag names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
ck::Tuple<F32>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvScaleAdd,
|
||||
F8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// First call: ConvFwdDefault specialization.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvScaleAdd>{});
|
||||
// Second call: ConvFwd1x1P0 specialization.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
ConvScaleAdd>{});
|
||||
// Third call: ConvFwd1x1S1P0 specialization.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_binary_outelementop_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
ConvScaleAdd>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_CONVSCALE_RELU
|
||||
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_convscale_relu_instance ${GROUPED_CONV3D_FWD_CONVSCALE_RELU})
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<3, NDHWGC,
// GKZYXC, ck::Tuple<>, NDHWGK, ConvFwdDefault, CombConvScaleRelu> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// CombConvScaleRelu is not declared in this file; it comes from the included headers.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (F8, F8 followed by F32, then two trailing F8 arguments) is assumed from the tag
// names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
ck::Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CombConvScaleRelu,
|
||||
F8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_f8_f32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
CombConvScaleRelu>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_outelementop_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3, NDHWGC, GKZYXC,
// ck::Tuple<>, NDHWGK, ConvFwdDefault, ConvScaleRelu> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot,
// including the two trailing F8 arguments after ConvScaleRelu, is assumed from the
// tag names - confirm against the class template.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
ck::Tuple<>,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ConvScaleRelu,
|
||||
F8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_outelementop_f8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
ConvScaleRelu>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_DYNAMIC_OP
|
||||
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_dynamic_op_instance ${GROUPED_CONV3D_FWD_DYNAMIC_OP})
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<3, NDHWGC, GKZYXC,
// Tuple<>, NDHWGK, ConvFwdDefault> object to add_device_operation_instances().
// Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (layout tags, data types, element-wise ops) is assumed from the tag names; the
// class template is declared outside this file - confirm there.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<3, NDHWGC, GKZYXC,
// Tuple<>, NDHWGK, ConvFwdDefault> object to add_device_operation_instances().
// Only the ConvFwdDefault specialization is passed.
// NOTE(review): the role of each DeviceGroupedConvFwdMultipleABD template slot
// (layout tags, data types, element-wise ops) is assumed from the tag names; the
// class template is declared outside this file - confirm there.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DynamicUnaryOp>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_dynamic_op_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU
|
||||
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
|
||||
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU})
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_bf16_instances<3,
// NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwdDefault> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): ck::Tuple<NDHWGK, G_K> / ck::Tuple<BF16, BF16> presumably describe
// two extra D tensors consumed by ScaleAddScaleAddRelu; the slot roles are assumed
// from the tag names - confirm against DeviceGroupedConvFwdMultipleABD.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ScaleAddScaleAddRelu>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_bf16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Forwards `instances` together with a temporary
// device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_f16_instances<3,
// NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwdDefault> object to
// add_device_operation_instances(). Only the ConvFwdDefault specialization is passed.
// NOTE(review): ck::Tuple<NDHWGK, G_K> / ck::Tuple<F16, F16> presumably describe
// two extra D tensors consumed by ScaleAddScaleAddRelu; the slot roles are assumed
// from the tag names - confirm against DeviceGroupedConvFwdMultipleABD.
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ScaleAddScaleAddRelu>>>& instances)
|
||||
{
|
||||
// The trailing `{}` creates a temporary of the instance-list type for the second argument.
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_scaleadd_relu_f16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user