mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Improve fmha_bwd tests performance (#2376)
* Avoid passing indices (std::vector) by value to host tensor's operator()
Each access requires 2 allocations and copies of the vector.
* Remove 1 unneeded vector copy from the slowest part of fmha_bwd's verification
* Compute ds_hp_host_ref in parallel
This sequential ForEach is the slowest part of validation and it benefits
from parallel computation.
* Do not use ForEach for simple copy and conversion of large tensors
These tensors all have the same shape {nhead, real_seqlen_q, real_seqlen_k} and
can be copied/converted without complex computations of linear indices.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_hp_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
|
||||
Reference in New Issue
Block a user