Transpose -> transpose

This commit is contained in:
danyao12
2024-05-29 16:54:26 +08:00
parent 58f61716b5
commit 87f73f30e8
2 changed files with 10 additions and 10 deletions

View File

@@ -778,7 +778,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// dP = dO@V x Z w/ dropout
// dP = dO@V w/o dropout
auto v_t_host_ref = v_host_refs[wb].Transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
@@ -815,13 +815,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout
auto p_t_lp_host_ref = p_lp_host_refs[wb].Transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
auto do_t_host_ref = do_host_ref.Transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
ck_tile::reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
// dQ = scale * dS@K^T
auto k_t_host_ref = k_host_refs[wb].Transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
ds_lp_host_ref,
k_t_host_ref,
@@ -831,8 +831,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
// dK = scale * dS^T@Q^T
auto ds_t_lp_host_ref = ds_lp_host_ref.Transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[wb].Transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
ds_t_lp_host_ref,
q_t_host_ref,

View File

@@ -481,7 +481,7 @@ struct HostTensor
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
HostTensor<T> Transpose(std::vector<size_t> axes = {}) const
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
{
if(axes.empty())
{
@@ -491,7 +491,7 @@ struct HostTensor
if(axes.size() != mDesc.get_num_of_dimension())
{
throw std::runtime_error(
"HostTensor::Transpose(): size of axes must match tensor dimension");
"HostTensor::transpose(): size of axes must match tensor dimension");
}
std::vector<size_t> tlengths, tstrides;
for(const auto& axis : axes)
@@ -504,9 +504,9 @@ struct HostTensor
return ret;
}
HostTensor<T> Transpose(std::vector<size_t> axes = {})
HostTensor<T> transpose(std::vector<size_t> axes = {})
{
return const_cast<HostTensor<T> const*>(this)->Transpose(axes);
return const_cast<HostTensor<T> const*>(this)->transpose(axes);
}
typename Data::iterator begin() { return mData.begin(); }