Add 1D instances

This commit is contained in:
kiefer
2025-08-28 15:57:30 +00:00
parent 68f9e73b5e
commit 78635fd74e
7 changed files with 215 additions and 49 deletions

View File

@@ -735,7 +735,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
// add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -745,7 +745,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
// add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
@@ -753,7 +754,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
// add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_int8_instances(
op_ptrs);
}
#endif
}

View File

@@ -10,51 +10,51 @@ namespace instance {
#ifdef CK_ENABLE_BF16
// grouped conv1d forward, GNWC/GKXC/GNWK
// void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
// GNWC,
// GKXC,
// Empty_Tuple,
// GNWK,
// BF16,
// BF16,
// Empty_Tuple,
// BF16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
// void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
// GNWC,
// GKXC,
// Empty_Tuple,
// GNWK,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
// void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_int8_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
// GNWC,
// GKXC,
// Empty_Tuple,
// GNWK,
// int8_t,
// int8_t,
// Empty_Tuple,
// int8_t,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16

View File

@@ -1,7 +1,11 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_conv1d_fwd_instance
xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp
xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp
wmma/device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_bf16_instance.cpp
wmma/device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_f16_instance.cpp
wmma/device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_int8_instance.cpp
)

View File

@@ -0,0 +1,55 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_fwd_wmma_cshufflev3_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,11 +1,6 @@
# TODO: Put the 3d instances back
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp)
if((GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9"))
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
else()
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
endif()
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9")