Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)

- Add support for direct store in epilogue instead of cshuffle
 - Add padding support for wave transfer without transpose
 - Add wave transfer with interleaved layout to support direct store
 - Enable new functionalities on GEMMs
 - Add optional new functionality support for grouped convolution fwd
 - Add some fast instances for grouped convolution fwd with new functionalities (proper tuning needed)
This commit is contained in:
Enrico Degregori
2026-01-14 11:02:19 +01:00
committed by GitHub
parent 51027474af
commit 693ff3bbb3
20 changed files with 948 additions and 155 deletions

View File

@@ -125,6 +125,8 @@ set(GROUPED_CONV2D_FWD
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp
)
# Add generated files for sharded instantiations.
include(ShardInstantiation)

View File

@@ -0,0 +1,51 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault,
GemmMNKPadding,
BF16>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0,
GemmDefault,
BF16>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault,
GemmMNKPadding,
F16>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0,
GemmDefault,
F16>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck