mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
WMMA gemm_add_relu_add_layernorm (#2989)
* Summary:
- Refactor epilogue (with CShuffle) to support fused operations:
- EpilogueCShuffleBase holds common parts
- EpilogueCShuffle: runs CShuffle and write out
- EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out
- Extend thread transfer v7r3:
- Support for intermediate data type different from src and dst type
- New functionality to write to dst buffer and keep data (to be able to use them for additional operations)
* Adress review comments
[ROCm/composable_kernel commit: 4ebc48a3cd]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,6 +15,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Row,
|
||||
@@ -78,6 +79,73 @@ void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_ins
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddReluAdd,
|
||||
PassThrough>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add + Relu + Add + Layernorm
|
||||
template <typename ALayout,
|
||||
@@ -136,29 +204,53 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<HLayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user