WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

 - Refactor epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds common parts
    - EpilogueCShuffle: runs CShuffle and write out
    - EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out

 - Extend thread transfer v7r3:
    - Support for intermediate data type different from src and dst type
    - New functionality to write to dst buffer and keep data (to be able to use them for additional operations)

* Adress review comments
This commit is contained in:
Enrico Degregori
2025-10-31 19:19:26 +01:00
committed by GitHub
parent e9596228ff
commit 4ebc48a3cd
23 changed files with 2678 additions and 332 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -167,6 +167,12 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
Tensor<HDataType> h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "h_m_n: " << h_m_n.mDesc << std::endl;
switch(init_method)
{
case 0: break;
@@ -312,9 +318,8 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
float gb_per_sec = num_byte / 1.E6 / ave_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
@@ -333,8 +338,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
}
else
{
if(time_kernel)
std::cout << op_name << " does not support this problem" << std::endl;
std::cout << op_name << " does not support this problem" << std::endl;
}
}