diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 98e4ab89e9..28b69a7333 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -162,10 +162,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { -#if defined(ALWAYS_INVOKE_SPLITKV_KERNEL) - return fmha_fwd_splitkv(traits, args, config); -#else - if(1 < args.num_splits) + if(1 < args.num_splits && args.p_drop == 0.0f) { return fmha_fwd_splitkv(traits, args, config); } @@ -173,8 +170,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits, { return fmha_fwd(traits, args, config); } -#endif -}; +} template bool run(const ck_tile::ArgParser& arg_parser) @@ -385,22 +381,13 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1}); ck_tile::HostTensor lse_acc_host( - 1 <= num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} - : std::array{1, 1, 1, 1}); - APP_DEBUG_STMTS - { - std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", " << nhead << ", " - << max_seqlen_q << std::endl; - } + 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} + : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 <= num_splits + 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} : std::array{1, 1, 1, 1, 1}); - APP_DEBUG_STMTS - { - std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", " << nhead - << ", " << shape_seqlen_q << ", " << hdim_v << std::endl; - } + // self define lse data layout as [batch, nhead, max_seqlen_q] ck_tile::HostTensor lse_host( lse ? std::array{batch, nhead, max_seqlen_q} @@ -669,7 +656,6 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); - lse_acc_buf.FromDevice(lse_acc_host.data()); randval_buf.FromDevice(randval_host.data()); float p_undrop = 1.0 - p_drop; uint8_t p_undrop_in_uint8_t = @@ -678,124 +664,11 @@ bool run(const ck_tile::ArgParser& arg_parser) bool pass = true; - APP_DEBUG_STMTS - { - printf("\n"); - printf("[POYENC][HOST] lse shape: %d, %d, %d, %d\n", - num_splits, - shape_batch, - nhead, - shape_seqlen_q); - } - for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - APP_DEBUG_STMTS { printf("[POYENC][HOST] wb: %d\n", wb); } - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - APP_DEBUG_STMTS - { - // lse_acc_host shape: num_splits, shape_batch, nhead, shape_seqlen_q - for(int i_split = 0; i_split < num_splits; ++i_split) - { - printf("[POYENC][HOST] i_split: %d\n", i_split); - printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0) = ", i_split, wb); - for(int row = 0; row < real_seqlen_q; ++row) - { - printf( - "%11.7f", - // printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0): %11.7f\n", i_split, wb, - lse_acc_host(i_split, wb, 0, row)); - } - printf("\n"); - } - } - - APP_DEBUG_STMTS - { - ck_tile::HostTensor lse_max({real_seqlen_q}); - for(int row = 0; row < real_seqlen_q; ++row) - { - lse_max(row) = -ck_tile::numeric::infinity(); - for(int i_split = 0; i_split < num_splits; ++i_split) - { - if(lse_max(row) < lse_acc_host(i_split, wb, 0, row)) - { - lse_max(row) = lse_acc_host(i_split, wb, 0, row); - } - } - } - printf("[POYENC][HOST] lse_max: "); - for(int row = 0; row < real_seqlen_q; ++row) - { - printf("%11.7f", lse_max(row)); - } - printf("\n"); - - static const auto get_validated_m = [](LSEDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - return raw_m == -ck_tile::numeric::infinity() - ? ck_tile::type_convert(0.f) - : raw_m; - }; - - ck_tile::HostTensor lse_sum({shape_seqlen_q}); - for(int row = 0; row < real_seqlen_q; ++row) - { - lse_sum(row) = 0; - for(int i_split = 0; i_split < num_splits; ++i_split) - { - lse_sum(row) += ck_tile::exp(lse_acc_host(i_split, wb, 0, row) - - get_validated_m(lse_max(row))); - } - } - printf("[POYENC][HOST] lse_sum: "); - for(int row = 0; row < real_seqlen_q; ++row) - { - printf("%11.7f", lse_sum(row)); - } - printf("\n"); - - ck_tile::HostTensor lse_logsum({real_seqlen_q}); - for(int row = 0; row < real_seqlen_q; ++row) - { - if(lse_sum(row) == 0.f || lse_sum(row) != lse_sum(row)) - { - lse_logsum(row) = ck_tile::numeric::infinity(); - } - else - { - lse_logsum(row) = ck_tile::log(lse_sum(row)) + get_validated_m(lse_max(row)); - } - } - - for(int row = 0; row < real_seqlen_q; ++row) - { - if(lse_logsum(row) == ck_tile::numeric::infinity()) - { - lse_logsum(row) = -ck_tile::numeric::infinity(); - } - } - - // lse_host shape: [batch, nhead, max_seqlen_q] - printf("[POYENC][DEVICE] lse_host: "); - for(int row = 0; row < real_seqlen_q; ++row) - { - printf("%11.7f", lse_host(wb, 0, row)); - } - printf("\n"); - - printf("[POYENC][HOST] lse_logsum: "); - for(int row = 0; row < real_seqlen_q; ++row) - { - printf("%11.7f", lse_logsum(row)); - } - printf("\n"); - } - // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); @@ -990,34 +863,8 @@ bool run(const ck_tile::ArgParser& arg_parser) // clang-format on auto [rtol, atol] = get_elimit(init_method); - bool cur_pass = true; - if(lse) - { - ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach( - [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); - - cur_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); - // if (cur_pass) std::cout << "LSE pass" << std::endl; - pass &= cur_pass; - if(!cur_pass) - { - std::cerr << "LSE mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - } - } -#if 1 - cur_pass = ck_tile::check_err( + bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); - // if (cur_pass) std::cout << "OUT pass" << std::endl; pass &= cur_pass; if(!cur_pass) { @@ -1029,7 +876,32 @@ bool run(const ck_tile::ArgParser& arg_parser) break; } -#endif + + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); + + bool lse_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= lse_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } } std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;