mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[fmha-bwd] Implement group-mode persistent scheduling with optimized state management
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
@@ -243,30 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
|
||||
|
||||
const fmha_bwd_traits fmha_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
};
|
||||
fmha_bwd_launcher launcher(fmha_traits);
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
@@ -395,8 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
const auto t0_launcher = std::chrono::high_resolution_clock::now();
|
||||
fmha_bwd_launcher launcher(fmha_bwd_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
});
|
||||
const auto t1_launcher = std::chrono::high_resolution_clock::now();
|
||||
const double launcher_ctor_ms =
|
||||
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
ck_tile::DeviceMem ws_buf(ws_size);
|
||||
ck_tile::gpu_timer prepare_ws_timer;
|
||||
prepare_ws_timer.start(nullptr);
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
prepare_ws_timer.stop(nullptr);
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -442,7 +448,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic
|
||||
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
|
||||
<< ", mask:" << mask << std::flush;
|
||||
<< ", mask:" << mask
|
||||
<< ", init:" << launcher_ctor_ms << "ms"
|
||||
<< ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
|
||||
Reference in New Issue
Block a user