mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
WMMA gemm_add_relu_add_layernorm (#2989)
* Summary:
- Refactor epilogue (with CShuffle) to support fused operations:
- EpilogueCShuffleBase holds common parts
- EpilogueCShuffle: runs CShuffle and write out
- EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out
- Extend thread transfer v7r3:
- Support for intermediate data type different from src and dst type
- New functionality to write to dst buffer and keep data (to be able to use them for additional operations)
* Adress review comments
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -167,6 +167,12 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
|
||||
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
|
||||
Tensor<HDataType> h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "h_m_n: " << h_m_n.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
@@ -312,9 +318,8 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
if(time_kernel)
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
|
||||
<< " GB/s, " << op_name << std::endl;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< op_name << std::endl;
|
||||
|
||||
if(ave_time < best_ave_time)
|
||||
{
|
||||
@@ -333,8 +338,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
if(time_kernel)
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user