diff --git a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp index f257e3d73e..d58b334e61 100644 --- a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp +++ b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp @@ -39,7 +39,7 @@ using fmha_bwd_convert_dq_0 = using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel; using convert_dq_trait_0 = - fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; + fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>; template <> void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, @@ -84,8 +84,8 @@ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits::AccDataType, typename FmhaBwdTypeConfig::KGradDataType, false, - false>>; + true>>; using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig::VGradDataType, false, - false>>; + true>>; using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile:: FmhaBwdDQDKDVKernel; @@ -146,8 +146,8 @@ using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, fmha_dropout_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, - false, - false, + true, + true, false, false, false>; @@ -171,7 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_() } // dot_do_o -using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits; +using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits; using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< typename FmhaBwdTypeConfig::ODataType, @@ -187,7 +187,7 @@ using fmha_bwd_dot_do_o_0 = using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel; -using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; +using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; template <> void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) @@ -251,10 +251,10 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && - (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && + (a.seqlen_q % 8 == 0 and a.seqlen_q % 8 == 0) && (a.seqlen_k % 8 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, @@ -264,13 +264,13 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, - false, - false, + true, + true, false, false, false>; using convert_dq_trait_ = - fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; + fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>; r = fmha_bwd_(s, a); return r; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 35b2f02e8a..21467f1d85 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1221,7 +1221,7 @@ struct FmhaBwdDQDKDVKernel const auto q_dram = pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -1232,7 +1232,7 @@ struct FmhaBwdDQDKDVKernel const auto k_dram = pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -1244,7 +1244,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); const auto lse_dram = [&]() { @@ -1805,7 +1805,7 @@ struct FmhaBwdOGradDotOKernel number<1>{}); return pad_tensor_view(o_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); const auto do_dram = [&]() { auto do_dram_naive = make_naive_tensor_view( @@ -1816,7 +1816,7 @@ struct FmhaBwdOGradDotOKernel number<1>{}); return pad_tensor_view(do_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); auto d_dram = [&]() { const auto d_dram_naive = make_naive_tensor_view_packed(