diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index cfb426cbc2..d62b908e33 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -621,8 +621,11 @@ bwd_result fmha_bwd_run(mode_enum mode, {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision ck_tile::HostTensor p_dropped_hp_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision - ck_tile::HostTensor p_lp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + // p_lp_g_m_n low precision used for fwd (with rp_undrop) + ck_tile::HostTensor p_fwd_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + // p_lp_g_m_n low precision used for bwd (no rp_undrop) + ck_tile::HostTensor p_lp_host_ref({nhead, real_seqlen_q, real_seqlen_k}); ck_tile::index_t nr = nhead / nhead_k; @@ -762,8 +765,11 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::reference_batched_dropout_randval( randval_host_ref, wb, drop_seed, drop_offset); ck_tile::reference_batched_dropout( - p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, 1.f); p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); + p_dropped_hp_host_ref.ForEach( + [&](auto& self, const auto& idx) { self(idx) *= rp_undrop; }); + p_fwd_host_ref = p_dropped_hp_host_ref.template CopyAsType(); ck_tile::HostTensor randval_host_result( {nhead, real_seqlen_q, real_seqlen_k}); @@ -789,12 +795,13 @@ bwd_result fmha_bwd_run(mode_enum mode, } else { - p_lp_host_ref = p_hp_host_ref.template CopyAsType(); + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); + p_fwd_host_ref = p_lp_host_ref; } // O = P * V ck_tile::reference_batched_gemm( - p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + p_fwd_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n // clang-format off // permute @@ -900,7 +907,7 @@ bwd_result fmha_bwd_run(mode_enum mode, if(p_drop > 0) { ck_tile::reference_batched_dropout( - dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop); + dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, 1.f); } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) @@ -911,7 +918,8 @@ bwd_result fmha_bwd_run(mode_enum mode, { do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * - ck_tile::type_convert(o_host_refs[ref_idx](i0, i1, o)); + ck_tile::type_convert(o_host_refs[ref_idx](i0, i1, o)) * + p_undrop; } ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert(p_hp_host_refs[ref_idx](i0, i1, i2) * @@ -935,7 +943,12 @@ bwd_result fmha_bwd_run(mode_enum mode, auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m ck_tile:: reference_batched_gemm( - p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + p_t_lp_host_ref, + do_t_host_ref, + dv_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(rp_undrop)); // dv_g_n_o = p_lp_g_n_m@do_g_o_m // dQ = scale * dS@K^T auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n @@ -945,7 +958,7 @@ bwd_result fmha_bwd_run(mode_enum mode, dq_host_ref, ck_tile::identity{}, ck_tile::identity{}, - ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n + ck_tile::scales(scale * rp_undrop)); // dq_g_m_k = ds_g_m_n@k_g_k_n // dK = scale * dS^T@Q^T auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m @@ -956,7 +969,7 @@ bwd_result fmha_bwd_run(mode_enum mode, dk_host_ref, ck_tile::identity{}, ck_tile::identity{}, - ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m + ck_tile::scales(scale * rp_undrop)); // dk_g_n_k = ds_g_n_m@q_g_k_m ck_tile::HostTensor dq_host_result( {nhead, real_seqlen_q, hdim_q}); // dq_g_m_k diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx1201.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx1201.txt index 7fc521f762..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx1201.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx1201.txt @@ -1 +0,0 @@ -tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt index 7fc521f762..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt @@ -1 +0,0 @@ -tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt index 7fc521f762..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt @@ -1 +0,0 @@ -tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt index 7fc521f762..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt @@ -1 +0,0 @@ -tile_example_fmha_bwd -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 -prec=bf16 -d=96 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 125d32aad8..bc7d2323d0 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -517,7 +517,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds - const index_t M0 = rows / M1; + const index_t M0 = integer_divide_ceil(rows, M1); const auto row_lens = make_tuple(M0, number{}); - const auto desc_0 = - make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto desc_0 = decltype(d0)( // set correct size (without padding) + d0.get_transforms(), + tensor_view_tmp.get_tensor_descriptor().get_element_space_size()); const auto desc_1 = transform_tensor_descriptor( desc_0, make_tuple(make_pass_through_transform(M0), diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc index a6defa8ccd..005d0ba083 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc @@ -20,10 +20,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128) { - if constexpr(std::is_same_v, F8>) - { - GTEST_SKIP() << "Skipping this test due to failures with F8"; - } constexpr int M = 128; constexpr int N = 128; constexpr int K = 128; @@ -48,11 +44,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128) { - if constexpr(std::is_same_v, F8>) - { - GTEST_SKIP() << "Skipping this test due to failures with F8"; - } - constexpr int M = 128; constexpr int N = 2048; constexpr int K = 128; @@ -77,11 +68,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128) { - if constexpr(std::is_same_v, F8>) - { - GTEST_SKIP() << "Skipping this test due to failures with F8"; - } - constexpr int M = 1024; constexpr int N = 128; constexpr int K = 128; @@ -106,11 +92,6 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128) { - if constexpr(std::is_same_v, F8>) - { - GTEST_SKIP() << "Skipping this test due to failures with F8"; - } - constexpr int M = 1024; constexpr int N = 2048; constexpr int K = 128;