From e156b5aebb8f5c1bec8a87e7cecb54cd7b43f30e Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Tue, 24 Jun 2025 20:45:24 +0600 Subject: [PATCH] Improve fmha_bwd tests performance (#2376) * Avoid passing indices (std::vector) by value to host tensor's operator() Each access requires 2 allocations and copies of the vector. * Remove 1 unneeded vector copy from the slowest part of fmha_bwd's verification * Compute ds_hp_host_ref in parallel This sequential ForEach is the slowest part of validation and it benefits from parallel computation. * Do not use ForEach for simple copy and conversion of large tensors These tensors all have the same shape {nhead, real_seqlen_q, real_seqlen_k} and can be copied/converted without complex computations of linear indices. [ROCm/composable_kernel commit: 77123600ee4b6fae077a2145b68b00a8b2ce9460] --- example/ck_tile/01_fmha/fmha_bwd.cpp | 47 +++++++++------------- include/ck/library/utility/host_tensor.hpp | 6 +-- include/ck_tile/host/host_tensor.hpp | 9 +++-- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index eaf99529f3..3b9cf09eb2 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "fmha_bwd.hpp" #include "ck_tile/host.hpp" @@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(p_drop > 0) { - p_hp_host_ref.ForEach( - [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + p_dropped_hp_host_ref = p_hp_host_ref; randval_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); }); ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); } else { - p_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); } // O = P * V @@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser) } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { - AccDataType do_dot_o = 0; - for(int o = 0; o < hdim_v; o++) - { - auto idx_gmo = idx_gmn; - idx_gmo[2] = o; - do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * - ck_tile::type_convert(o_host_refs[wb](idx_gmo)); - } - self(idx_gmn) = ck_tile::type_convert( - p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); - }); + ck_tile::make_ParallelTensorFunctor( + [&](auto i0, auto i1, auto i2) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + } + ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( + p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + }, + ds_hp_host_ref.mDesc.get_lengths()[0], + ds_hp_host_ref.mDesc.get_lengths()[1], + ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); if(use_dbias) { - 
ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + dbias_host_ref = ds_hp_host_ref.template CopyAsType(); } - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 06e33afd20..286dffc36c 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -167,7 +167,7 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } - std::size_t GetOffsetFromMultiIndex(std::vector iss) const + std::size_t GetOffsetFromMultiIndex(const std::vector& iss) const { return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } @@ -600,12 +600,12 @@ struct Tensor ck::packed_size_v>]; } - T& operator()(std::vector idx) + T& operator()(const std::vector& idx) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } - const T& operator()(std::vector idx) const + const T& operator()(const std::vector& idx) const { return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index deaa158d50..b8c764809c 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -230,7 +230,7 @@ struct HostTensorDescriptor * @param iss Vector containing the multi-dimensional indices * @return The calculated linear offset as a size_t */ - std::size_t GetOffsetFromMultiIndex(std::vector iss) const + std::size_t GetOffsetFromMultiIndex(const std::vector& iss) const { return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } @@ -540,9 +540,12 @@ struct HostTensor 
return mData[GetOffsetFromMultiIndex(is...)]; } - T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; } + T& operator()(const std::vector<std::size_t>& idx) + { + return mData[GetOffsetFromMultiIndex(idx)]; + } - const T& operator()(std::vector<std::size_t> idx) const + const T& operator()(const std::vector<std::size_t>& idx) const { return mData[GetOffsetFromMultiIndex(idx)]; }