From 87f73f30e8ee7f6d1659e598e6e8a3f352d8f579 Mon Sep 17 00:00:00 2001 From: danyao12 Date: Wed, 29 May 2024 16:54:26 +0800 Subject: [PATCH] Transpose -> transpose --- example/ck_tile/01_fmha/fmha_bwd.cpp | 12 ++++++------ include/ck_tile/host/host_tensor.hpp | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9627c2bbf5..b1249b5eda 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -778,7 +778,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // dP = dO@V x Z w/ dropout // dP = dO@V w/o dropout - auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o ck_tile::reference_batched_gemm( do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o @@ -815,13 +815,13 @@ bool run(const ck_tile::ArgParser& arg_parser) // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout - auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m - auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m ck_tile::reference_batched_gemm( p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m // dQ = scale * dS@K^T - auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n ck_tile::reference_batched_gemm( ds_lp_host_ref, k_t_host_ref, @@ -831,8 +831,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n // dK = scale * dS^T@Q^T - auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m - auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m ck_tile::reference_batched_gemm( ds_t_lp_host_ref, q_t_host_ref, diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index bb60fc8172..43405ee69b 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -481,7 +481,7 @@ struct HostTensor return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } - HostTensor Transpose(std::vector axes = {}) const + HostTensor transpose(std::vector axes = {}) const { if(axes.empty()) { @@ -491,7 +491,7 @@ struct HostTensor if(axes.size() != mDesc.get_num_of_dimension()) { throw std::runtime_error( - "HostTensor::Transpose(): size of axes must match tensor dimension"); + "HostTensor::transpose(): size of axes must match tensor dimension"); } std::vector tlengths, tstrides; for(const auto& axis : axes) @@ -504,9 +504,9 @@ struct HostTensor return ret; } - HostTensor Transpose(std::vector axes = {}) + HostTensor transpose(std::vector axes = {}) { - return const_cast const*>(this)->Transpose(axes); + return const_cast const*>(this)->transpose(axes); } typename Data::iterator begin() { return mData.begin(); }