diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b9d653b059..2ed135d638 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -869,21 +869,6 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::scales{scale_p}; - else - return ck_tile::identity{}; - }(); - - auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); - else - return ck_tile::identity{}; - }(); - const auto init_traits = [&](auto& traits) { traits.hdim_q = hdim_q; traits.hdim_v = hdim_v; @@ -1088,13 +1073,28 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); + + auto p_compute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::identity{}; + }(); + float p_undrop = 1.0 - p_drop; uint8_t p_undrop_in_uint8_t = uint8_t(std::floor(p_undrop * std::numeric_limits::max())); float rp_undrop = 1.0 / p_undrop; bool pass = true; - for(ck_tile::index_t wb = 0; wb < batch; ++wb) { const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];