[fmha-bwd] Implement group-mode persistent scheduling with optimized state management

This commit is contained in:
Ding, Yi
2026-04-16 22:27:25 -05:00
parent 1a9404ac96
commit 92f2ed758e
2 changed files with 206 additions and 106 deletions

View File

@@ -9,6 +9,7 @@
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <chrono>
#include <cstring>
#include <functional>
#include <numeric>
@@ -243,30 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
const fmha_bwd_traits fmha_traits{
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
};
fmha_bwd_launcher launcher(fmha_traits);
const size_t ws_size = launcher.workspace_size;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
@@ -395,8 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
const auto t0_launcher = std::chrono::high_resolution_clock::now();
fmha_bwd_launcher launcher(fmha_bwd_traits{
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
});
const auto t1_launcher = std::chrono::high_resolution_clock::now();
const double launcher_ctor_ms =
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
const size_t ws_size = launcher.workspace_size;
ck_tile::DeviceMem ws_buf(ws_size);
ck_tile::gpu_timer prepare_ws_timer;
prepare_ws_timer.start(nullptr);
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
prepare_ws_timer.stop(nullptr);
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -442,7 +448,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
<< ", mask:" << mask << std::flush;
<< ", mask:" << mask
<< ", init:" << launcher_ctor_ms << "ms"
<< ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
auto fmha_args = [&]() {
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,