WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

 - Refactor epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds common parts
    - EpilogueCShuffle: runs CShuffle and writes out
    - EpilogueWelfordCShuffle: holds Welford-specific arguments; runs CShuffle, the write-out, the first part of Welford, and the Welford write-out

 - Extend thread transfer v7r3:
    - Support for intermediate data type different from src and dst type
    - New functionality to write to dst buffer and keep data (to be able to use them for additional operations)

* Address review comments

[ROCm/composable_kernel commit: 4ebc48a3cd]
This commit is contained in:
Enrico Degregori
2025-10-31 19:19:26 +01:00
committed by GitHub
parent 2136eddf8a
commit e6be7bcc2a
23 changed files with 2678 additions and 332 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -15,6 +15,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Row,
@@ -78,6 +79,73 @@ void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_ins
PassThrough,
AddReluAdd,
PassThrough>>>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// WMMA variant, declared only when CK_USE_WMMA is defined (declaration only; no
// definition appears in this section).
// Parameter: a non-const reference to a vector of
// std::unique_ptr<DeviceGemmMultipleDLayernorm<...>> whose first two template
// arguments are Row, Row, with F16 data types and the AddReluAdd elementwise op.
// The instance factory in this header calls it with `op_ptrs` in the same branch
// as the XDL `mk_kn` variant.
// NOTE(review): presumably appends the WMMA instances for this configuration to
// the vector -- confirm against the definition.
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
// WMMA variant, declared only when CK_USE_WMMA is defined (declaration only; no
// definition appears in this section).
// Parameter: a non-const reference to a vector of
// std::unique_ptr<DeviceGemmMultipleDLayernorm<...>> whose first two template
// arguments are Row, Col, with F16 data types and the AddReluAdd elementwise op.
// The instance factory in this header calls it with `op_ptrs` in the branch where
// ALayout is Row, BLayout is Col, and D0Layout/D1Layout/HLayout are all Row.
// NOTE(review): presumably appends the WMMA instances for this configuration to
// the vector -- confirm against the definition.
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
// WMMA variant, declared only when CK_USE_WMMA is defined (declaration only; no
// definition appears in this section).
// Parameter: a non-const reference to a vector of
// std::unique_ptr<DeviceGemmMultipleDLayernorm<...>> whose first two template
// arguments are Col, Row, with F16 data types and the AddReluAdd elementwise op.
// The instance factory in this header calls it with `op_ptrs` in the branch where
// ALayout is Col, BLayout is Row, and D0Layout/D1Layout/HLayout are all Row.
// NOTE(review): presumably appends the WMMA instances for this configuration to
// the vector -- confirm against the definition.
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
// WMMA variant, declared only when CK_USE_WMMA is defined (declaration only; no
// definition appears in this section).
// Parameter: a non-const reference to a vector of
// std::unique_ptr<DeviceGemmMultipleDLayernorm<...>> whose first two template
// arguments are Col, Col, with F16 data types and the AddReluAdd elementwise op.
// The instance factory in this header calls it with `op_ptrs` in the branch where
// ALayout is Col, BLayout is Col, and D0Layout/D1Layout/HLayout are all Row.
// NOTE(review): presumably appends the WMMA instances for this configuration to
// the vector -- confirm against the definition.
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + Relu + Add + Layernorm
template <typename ALayout,
@@ -136,29 +204,53 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
#endif
}
}