Merge branch 'develop' into gptoss_sink

This commit is contained in:
Linjun-AMD
2025-12-29 09:57:27 +08:00
committed by GitHub
8 changed files with 29 additions and 37 deletions

View File

@@ -621,8 +621,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
// p_lp_g_m_n low precision used for fwd (with rp_undrop)
ck_tile::HostTensor<GemmDataType> p_fwd_host_ref({nhead, real_seqlen_q, real_seqlen_k});
// p_lp_g_m_n low precision used for bwd (no rp_undrop)
ck_tile::HostTensor<GemmDataType> p_lp_host_ref({nhead, real_seqlen_q, real_seqlen_k});
ck_tile::index_t nr = nhead / nhead_k;
@@ -762,8 +765,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::reference_batched_dropout_randval(
randval_host_ref, wb, drop_seed, drop_offset);
ck_tile::reference_batched_dropout(
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, 1.f);
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
p_dropped_hp_host_ref.ForEach(
[&](auto& self, const auto& idx) { self(idx) *= rp_undrop; });
p_fwd_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
{nhead, real_seqlen_q, real_seqlen_k});
@@ -789,12 +795,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
}
else
{
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
p_fwd_host_ref = p_lp_host_ref;
}
// O = P * V
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
p_fwd_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
// clang-format off
// permute
@@ -900,7 +907,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
if(p_drop > 0)
{
ck_tile::reference_batched_dropout(
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop);
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, 1.f);
}
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
@@ -911,7 +918,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
{
do_dot_o +=
ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o));
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o)) *
p_undrop;
}
ds_hp_host_ref(i0, i1, i2) =
ck_tile::type_convert<AccDataType>(p_hp_host_refs[ref_idx](i0, i1, i2) *
@@ -935,7 +943,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
ck_tile::
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
p_t_lp_host_ref,
do_t_host_ref,
dv_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(rp_undrop)); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
// dQ = scale * dS@K^T
auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
@@ -945,7 +958,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
ck_tile::scales(scale * rp_undrop)); // dq_g_m_k = ds_g_m_n@k_g_k_n
// dK = scale * dS^T@Q^T
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
@@ -956,7 +969,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
dk_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
ck_tile::scales(scale * rp_undrop)); // dk_g_n_k = ds_g_n_m@q_g_k_m
ck_tile::HostTensor<QGradDataType> dq_host_result(
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k

View File

@@ -1 +0,0 @@
tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1

View File

@@ -1 +0,0 @@
tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1

View File

@@ -1 +0,0 @@
tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1

View File

@@ -1 +0,0 @@
tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1

View File

@@ -517,7 +517,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
"wrong!");
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
static_assert(NWarp == 4);
static_assert(MWarp == 1);
using CWarpTensor = typename WG::CWarpTensor;

View File

@@ -113,11 +113,13 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const index_t M0 = integer_divide_ceil(rows, M1);
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_0 = decltype(d0)( // set correct size (without padding)
d0.get_transforms(),
tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),

View File

@@ -20,10 +20,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128)
{
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
constexpr int M = 128;
constexpr int N = 128;
constexpr int K = 128;
@@ -48,11 +44,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128)
{
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
constexpr int M = 128;
constexpr int N = 2048;
constexpr int K = 128;
@@ -77,11 +68,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128)
{
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
constexpr int M = 1024;
constexpr int N = 128;
constexpr int K = 128;
@@ -106,11 +92,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096)
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128)
{
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
{
GTEST_SKIP() << "Skipping this test due to failures with F8";
}
constexpr int M = 1024;
constexpr int N = 2048;
constexpr int K = 128;