|
|
|
|
@@ -621,8 +621,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
|
|
|
|
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
|
|
|
|
|
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
|
|
|
|
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
|
|
|
|
|
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
|
|
|
|
|
|
|
|
|
// p_lp_g_m_n low precision used for fwd (with rp_undrop)
|
|
|
|
|
ck_tile::HostTensor<GemmDataType> p_fwd_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
|
|
|
|
// p_lp_g_m_n low precision used for bwd (no rp_undrop)
|
|
|
|
|
ck_tile::HostTensor<GemmDataType> p_lp_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
|
|
|
|
|
|
|
|
|
ck_tile::index_t nr = nhead / nhead_k;
|
|
|
|
|
|
|
|
|
|
@@ -762,8 +765,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
ck_tile::reference_batched_dropout_randval(
|
|
|
|
|
randval_host_ref, wb, drop_seed, drop_offset);
|
|
|
|
|
ck_tile::reference_batched_dropout(
|
|
|
|
|
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
|
|
|
|
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, 1.f);
|
|
|
|
|
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
|
|
|
|
p_dropped_hp_host_ref.ForEach(
|
|
|
|
|
[&](auto& self, const auto& idx) { self(idx) *= rp_undrop; });
|
|
|
|
|
p_fwd_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
|
|
|
|
|
|
|
|
|
ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
|
|
|
|
|
{nhead, real_seqlen_q, real_seqlen_k});
|
|
|
|
|
@@ -789,12 +795,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
|
|
|
|
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
|
|
|
|
p_fwd_host_ref = p_lp_host_ref;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// O = P * V
|
|
|
|
|
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
|
|
|
|
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
|
|
|
|
p_fwd_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
|
|
|
|
|
|
|
|
|
// clang-format off
|
|
|
|
|
// permute
|
|
|
|
|
@@ -900,7 +907,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
if(p_drop > 0)
|
|
|
|
|
{
|
|
|
|
|
ck_tile::reference_batched_dropout(
|
|
|
|
|
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop);
|
|
|
|
|
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, 1.f);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
|
|
|
|
@@ -911,7 +918,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
{
|
|
|
|
|
do_dot_o +=
|
|
|
|
|
ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
|
|
|
|
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o));
|
|
|
|
|
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o)) *
|
|
|
|
|
p_undrop;
|
|
|
|
|
}
|
|
|
|
|
ds_hp_host_ref(i0, i1, i2) =
|
|
|
|
|
ck_tile::type_convert<AccDataType>(p_hp_host_refs[ref_idx](i0, i1, i2) *
|
|
|
|
|
@@ -935,7 +943,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
|
|
|
|
ck_tile::
|
|
|
|
|
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
|
|
|
|
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
|
|
|
|
p_t_lp_host_ref,
|
|
|
|
|
do_t_host_ref,
|
|
|
|
|
dv_host_ref,
|
|
|
|
|
ck_tile::identity{},
|
|
|
|
|
ck_tile::identity{},
|
|
|
|
|
ck_tile::scales(rp_undrop)); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
|
|
|
|
|
|
|
|
|
// dQ = scale * dS@K^T
|
|
|
|
|
auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
|
|
|
|
@@ -945,7 +958,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
dq_host_ref,
|
|
|
|
|
ck_tile::identity{},
|
|
|
|
|
ck_tile::identity{},
|
|
|
|
|
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
|
|
|
|
ck_tile::scales(scale * rp_undrop)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
|
|
|
|
|
|
|
|
|
// dK = scale * dS^T@Q^T
|
|
|
|
|
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
|
|
|
|
@@ -956,7 +969,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
|
|
|
|
dk_host_ref,
|
|
|
|
|
ck_tile::identity{},
|
|
|
|
|
ck_tile::identity{},
|
|
|
|
|
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
|
|
|
|
ck_tile::scales(scale * rp_undrop)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
|
|
|
|
|
|
|
|
|
ck_tile::HostTensor<QGradDataType> dq_host_result(
|
|
|
|
|
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
|
|
|
|
|