mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[fmha-bwd] Implement group-mode persistent scheduling with optimized state management
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
@@ -243,30 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
|
||||
|
||||
const fmha_bwd_traits fmha_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
};
|
||||
fmha_bwd_launcher launcher(fmha_traits);
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
@@ -395,8 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
const auto t0_launcher = std::chrono::high_resolution_clock::now();
|
||||
fmha_bwd_launcher launcher(fmha_bwd_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
});
|
||||
const auto t1_launcher = std::chrono::high_resolution_clock::now();
|
||||
const double launcher_ctor_ms =
|
||||
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
ck_tile::DeviceMem ws_buf(ws_size);
|
||||
ck_tile::gpu_timer prepare_ws_timer;
|
||||
prepare_ws_timer.start(nullptr);
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
prepare_ws_timer.stop(nullptr);
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -442,7 +448,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic
|
||||
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
|
||||
<< ", mask:" << mask << std::flush;
|
||||
<< ", mask:" << mask
|
||||
<< ", init:" << launcher_ctor_ms << "ms"
|
||||
<< ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
|
||||
@@ -27,6 +27,31 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Per-CU state for group-mode deterministic persistent scheduling.
|
||||
// Packed into a single array to reduce kargs pointer count (5 pointers → 1).
|
||||
// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad).
|
||||
struct alignas(16) FmhaBwdGroupPersistentCuState
|
||||
{
|
||||
// NOTE(review): field order and size are layout-sensitive — the GPU side reads this
// struct through a single pointer and the surrounding comments rely on aligned
// 16-byte loads; do not reorder or insert members without re-checking sizeof/alignas.
index_t w_lo; // global position of this CU's first K-chunk (= pb + head*hw + c*sq)
|
||||
index_t w_hi; // global position of next CU's first K-chunk (exclusive upper bound)
|
||||
index_t ibatch; // first batch this CU touches (batch_size = no work sentinel)
|
||||
index_t isplit; // isplit for the first (batch, head) this CU touches
|
||||
index_t head_start; // head index for the first batch this CU touches
|
||||
index_t c_start; // chunk index for the first (batch, head) this CU touches
|
||||
// 8 bytes implicit padding (6 x 4-byte index_t rounded up to the 16-byte alignment)
|
||||
};
|
||||
|
||||
// Per-batch precomputed values used in the group-mode persistent dispatch loop.
|
||||
// Avoids per-iteration reads from seqstart_q/k_ptr and nsplits_ptr.
|
||||
// alignas(16): sizeof == 16 (3×4 + 4 pad), fits in a single 128-bit load.
|
||||
struct alignas(16) FmhaBwdBatchState
|
||||
{
|
||||
// NOTE(review): layout-sensitive — alignas(16) is chosen so the whole struct fits
// one aligned 16-byte load on the GPU side; do not reorder or add members without
// re-checking sizeof.
index_t sq; // seqlen_q for this batch (seqstart_q[b+1] - seqstart_q[b])
|
||||
index_t nc; // number of K-chunks: ceil(seqlen_k / kN0)
|
||||
index_t nsplits; // dq_acc split count for this batch
|
||||
// 4 bytes implicit padding (3 x 4-byte index_t rounded up to the 16-byte alignment)
|
||||
};
|
||||
|
||||
template <typename AccDataType, bool kIsGroupMode, bool kIsDeterministic>
|
||||
struct FmhaBwdWorkspaceManager
|
||||
{
|
||||
@@ -77,10 +102,20 @@ struct FmhaBwdWorkspaceManager
|
||||
return integer_least_multiple(sizeof(index_t) * (batch + 1), ALIGNMENT);
|
||||
return 0;
|
||||
}
|
||||
CK_TILE_HOST static size_t GetCuStartIbatchSize(const int num_cus)
|
||||
// cu_state[num_cus]: per-CU persistent state packed into one array (group det only).
|
||||
// Replaces separate cu_start_ibatch / cu_wlo / cu_isplit / cu_head_start / cu_c_start arrays.
|
||||
CK_TILE_HOST static size_t GetCuStateSize(const int num_cus)
|
||||
{
|
||||
if constexpr(kIsGroupMode && kIsDeterministic)
|
||||
return integer_least_multiple(sizeof(index_t) * num_cus, ALIGNMENT);
|
||||
return integer_least_multiple(sizeof(FmhaBwdGroupPersistentCuState) * num_cus,
|
||||
ALIGNMENT);
|
||||
return 0;
|
||||
}
|
||||
// batch_state[batch]: per-batch sq/nc/nsplits for group det dispatch loop.
|
||||
CK_TILE_HOST static size_t GetBatchStateSize(const int batch)
|
||||
{
|
||||
if constexpr(kIsGroupMode && kIsDeterministic)
|
||||
return integer_least_multiple(sizeof(FmhaBwdBatchState) * batch, ALIGNMENT);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -89,8 +124,11 @@ struct FmhaBwdWorkspaceManager
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch) +
|
||||
GetPrefixBatchSize(batch) + GetCuStartIbatchSize(get_num_cus());
|
||||
const size_t raw = GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) +
|
||||
GetDqAccOffsetsSize(batch) + GetPrefixBatchSize(batch) +
|
||||
GetCuStateSize(get_num_cus()) + GetBatchStateSize(batch);
|
||||
// Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
|
||||
return integer_least_multiple(raw, static_cast<size_t>(4096));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
|
||||
@@ -103,10 +141,14 @@ struct FmhaBwdWorkspaceManager
|
||||
{
|
||||
return GetDqAccSplitsSize<false>(batch) + GetDqAccOffsetsSize(batch);
|
||||
}
|
||||
CK_TILE_HOST static size_t GetCuStartIbatchOffset(const int batch)
|
||||
CK_TILE_HOST static size_t GetCuStateOffset(const int batch)
|
||||
{
|
||||
return GetPrefixBatchOffset(batch) + GetPrefixBatchSize(batch);
|
||||
}
|
||||
CK_TILE_HOST static size_t GetBatchStateOffset(const int batch)
|
||||
{
|
||||
return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus());
|
||||
}
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
|
||||
{
|
||||
@@ -155,9 +197,10 @@ struct FmhaBwdWorkspaceManager
|
||||
auto* prefix_batch = reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
|
||||
GetDqAccSplitsSize<false>(batch_size) +
|
||||
GetDqAccOffsetsSize(batch_size));
|
||||
auto* cu_start_ibatch = reinterpret_cast<index_t*>(
|
||||
reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize<false>(batch_size) +
|
||||
GetDqAccOffsetsSize(batch_size) + GetPrefixBatchSize(batch_size));
|
||||
auto* cu_states = reinterpret_cast<FmhaBwdGroupPersistentCuState*>(
|
||||
reinterpret_cast<char*>(cpu_ws) + GetCuStateOffset(batch_size));
|
||||
auto* batch_states = reinterpret_cast<FmhaBwdBatchState*>(
|
||||
reinterpret_cast<char*>(cpu_ws) + GetBatchStateOffset(batch_size));
|
||||
|
||||
prefix_batch[0] = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
@@ -168,15 +211,19 @@ struct FmhaBwdWorkspaceManager
|
||||
}
|
||||
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
|
||||
|
||||
// Step 2: compute nsplits[b] = per_batch_max_cus[b] (for dq_acc split dimension)
|
||||
// Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
|
||||
nsplits[b] = 1 + (rest_workload > 0 && target_w > 0
|
||||
const index_t ns = 1 + (rest_workload > 0 && target_w > 0
|
||||
? integer_divide_ceil(rest_workload, target_w)
|
||||
: 0);
|
||||
nsplits[b] = ns;
|
||||
batch_states[b].sq = sq;
|
||||
batch_states[b].nc = nc;
|
||||
batch_states[b].nsplits = ns;
|
||||
}
|
||||
|
||||
// Step 3: compute per-batch dq_acc offsets (compact layout, depends on nsplits)
|
||||
@@ -191,18 +238,65 @@ struct FmhaBwdWorkspaceManager
|
||||
offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
|
||||
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
|
||||
|
||||
// Step 4: fill cu_start_ibatch via two-pointer scan O(batch + num_cus)
|
||||
// Step 4: fill cu_states via two-pointer scan.
|
||||
// w_lo = global position of the first K-chunk: pb + head_start*hw + c_start*sq.
|
||||
// This makes w_chunk track true global K-chunk positions on GPU, so w_chunk < w_hi
|
||||
// correctly identifies boundaries without off-by-one overlap between adjacent CUs.
|
||||
// w_hi is set in a post-pass to cu_states[c+1].w_lo.
|
||||
index_t cu_lo = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t hw = nc * sq;
|
||||
const index_t pb = prefix_batch[b];
|
||||
const index_t cu_hi =
|
||||
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
|
||||
for(index_t c = cu_lo; c < cu_hi; ++c)
|
||||
cu_start_ibatch[c] = b;
|
||||
{
|
||||
const index_t w_lo = c * target_w;
|
||||
cu_states[c].ibatch = b;
|
||||
if(hw > 0)
|
||||
{
|
||||
const index_t head_start =
|
||||
max(static_cast<index_t>((w_lo - pb) / hw), index_t(0));
|
||||
const index_t w_head = pb + head_start * hw;
|
||||
const index_t wc_start = max(w_lo - w_head, index_t(0));
|
||||
const index_t c_start =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
|
||||
cu_states[c].isplit =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
|
||||
cu_states[c].head_start = head_start;
|
||||
cu_states[c].c_start = c_start;
|
||||
// w_lo = true global start of first K-chunk for this CU
|
||||
cu_states[c].w_lo = pb + head_start * hw + c_start * sq;
|
||||
}
|
||||
else
|
||||
{
|
||||
cu_states[c].isplit = 0;
|
||||
cu_states[c].head_start = 0;
|
||||
cu_states[c].c_start = 0;
|
||||
cu_states[c].w_lo = pb; // hw==0: degenerate, w_lo=batch start
|
||||
}
|
||||
}
|
||||
cu_lo = cu_hi;
|
||||
}
|
||||
// Inactive CUs: use total_w as w_lo sentinel so the post-pass sets
|
||||
// the last active CU's w_hi = total_w correctly.
|
||||
const index_t total_w = prefix_batch[batch_size];
|
||||
for(index_t c = cu_lo; c < num_cus; ++c)
|
||||
cu_start_ibatch[c] = batch_size; // sentinel: this CU has no work
|
||||
{
|
||||
cu_states[c].w_lo = total_w;
|
||||
cu_states[c].w_hi = total_w;
|
||||
cu_states[c].ibatch = batch_size; // sentinel → early return on GPU
|
||||
cu_states[c].isplit = 0;
|
||||
cu_states[c].head_start = 0;
|
||||
cu_states[c].c_start = 0;
|
||||
}
|
||||
// Post-pass: set w_hi[c] = w_lo[c+1] (global start of next CU's first K-chunk).
|
||||
for(index_t c = 0; c < num_cus - 1; ++c)
|
||||
cu_states[c].w_hi = cu_states[c + 1].w_lo;
|
||||
cu_states[num_cus - 1].w_hi = total_w;
|
||||
|
||||
return sizeof(AccDataType) * dq_acc_elems;
|
||||
}
|
||||
@@ -525,8 +619,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t batch; // used for persistent kernel implementation
|
||||
const ck_tile::index_t* nsplits_ptr; // per-batch nsplits (group) or single scalar (batch)
|
||||
// group mode persistent scheduling tables (read from CPU workspace by GPU):
|
||||
const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
|
||||
const ck_tile::index_t* cu_start_ibatch_ptr; // first batch for each CU, size [num_cus]
|
||||
const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
|
||||
const FmhaBwdGroupPersistentCuState* cu_state_ptr; // per-CU packed state, size [num_cus]
|
||||
const FmhaBwdBatchState* batch_state_ptr; // per-batch sq/nc/nsplits, size [batch]
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeKargs
|
||||
@@ -952,8 +1047,13 @@ struct FmhaBwdDQDKDVKernel
|
||||
ws + WorkspaceManager::GetDqAccSplitsOffset(batch));
|
||||
kargs.prefix_batch_ptr = reinterpret_cast<const ck_tile::index_t*>(
|
||||
ws + WorkspaceManager::GetPrefixBatchOffset(batch));
|
||||
kargs.cu_start_ibatch_ptr = reinterpret_cast<const ck_tile::index_t*>(
|
||||
ws + WorkspaceManager::GetCuStartIbatchOffset(batch));
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
kargs.cu_state_ptr = reinterpret_cast<const FmhaBwdGroupPersistentCuState*>(
|
||||
ws + WorkspaceManager::GetCuStateOffset(batch));
|
||||
kargs.batch_state_ptr = reinterpret_cast<const FmhaBwdBatchState*>(
|
||||
ws + WorkspaceManager::GetBatchStateOffset(batch));
|
||||
}
|
||||
}
|
||||
|
||||
return kargs;
|
||||
@@ -1010,6 +1110,14 @@ struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
static_assert(!kUseQrQtrDorPipeline,
|
||||
"Persistent kernel is not compatible with QR/QTR/DOR pipeline");
|
||||
|
||||
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
|
||||
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
|
||||
if constexpr(kHasMask == false)
|
||||
return x;
|
||||
else
|
||||
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
|
||||
};
|
||||
if constexpr(!kIsGroupMode)
|
||||
{
|
||||
// Batch mode persistent: uniform seqlen_k across all batches
|
||||
@@ -1027,14 +1135,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return; // worker_id exceeds total jobs, exit early
|
||||
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
|
||||
|
||||
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
|
||||
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
|
||||
if constexpr(kHasMask == false)
|
||||
return x;
|
||||
else
|
||||
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
|
||||
};
|
||||
|
||||
const auto n_splits = kargs.nsplits_ptr[0];
|
||||
index_t job_id = begin_job_id;
|
||||
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
|
||||
@@ -1054,68 +1154,60 @@ struct FmhaBwdDQDKDVKernel
|
||||
else
|
||||
{
|
||||
// Group mode persistent: variable seqlen per batch, dispatch via gist algo.
|
||||
// Each CU independently determines its workload interval using prefix_batch.
|
||||
const index_t cu_id = blockIdx.x;
|
||||
const index_t num_cu = gridDim.x;
|
||||
const index_t nbatch = kargs.batch;
|
||||
// Per-CU state (w_lo, w_hi, ibatch, isplit, head_start, c_start) is packed
|
||||
// in a single struct array to minimise kargs pointer count (5 → 1).
|
||||
// Remap block→CU: interleave SEs so consecutive blocks hit different SEs,
|
||||
// spreading dq_acc writes across HBM channels.
|
||||
const index_t cu_id = blockIdx.x / 8 + (blockIdx.x % 8) * 32;
|
||||
|
||||
// prefix_batch[nbatch] = total workload (nhead * sum(nc[b] * sq[b]))
|
||||
const index_t total_w = kargs.prefix_batch_ptr[nbatch];
|
||||
if(total_w == 0)
|
||||
return;
|
||||
const index_t target_w = integer_divide_ceil(total_w, num_cu);
|
||||
// Load all per-CU fields through a single pointer; pointer dies after loads.
|
||||
const FmhaBwdGroupPersistentCuState* cs = kargs.cu_state_ptr + cu_id;
|
||||
const index_t w_hi = amd_wave_read_first_lane(cs->w_hi);
|
||||
index_t ibatch = amd_wave_read_first_lane(cs->ibatch);
|
||||
index_t isplit = amd_wave_read_first_lane(cs->isplit);
|
||||
index_t head_start = amd_wave_read_first_lane(cs->head_start);
|
||||
index_t c_start_0 = amd_wave_read_first_lane(cs->c_start);
|
||||
index_t w_chunk = amd_wave_read_first_lane(cs->w_lo);
|
||||
if(ibatch >= kargs.batch)
|
||||
return; // this CU has no work (sentinel: ibatch == batch_size)
|
||||
|
||||
const index_t w_lo = amd_wave_read_first_lane(cu_id * target_w);
|
||||
const index_t w_hi = amd_wave_read_first_lane(
|
||||
min(static_cast<index_t>((cu_id + 1) * target_w), total_w));
|
||||
if(w_lo >= total_w)
|
||||
return; // this CU has no work
|
||||
|
||||
for(index_t ibatch = kargs.cu_start_ibatch_ptr[cu_id]; ibatch < nbatch;
|
||||
++ibatch)
|
||||
// w_chunk tracks the global K-chunk position; the loop exits when w_chunk
|
||||
// reaches w_hi (this CU's exclusive upper bound). ibatch < batch is guaranteed
|
||||
// on entry; the check inside guards against the rare case where head_start
|
||||
// reaches nhead_q and ibatch is incremented past batch before w_chunk catches
|
||||
// up.
|
||||
do
|
||||
{
|
||||
const index_t pb = kargs.prefix_batch_ptr[ibatch];
|
||||
if(pb >= w_hi)
|
||||
break; // all remaining batches are past this CU's interval
|
||||
if(ibatch >= kargs.batch)
|
||||
return;
|
||||
// sq/nc/nsplits are read inline (not pre-hoisted) to shorten their live
|
||||
// range across the inlined run_() body, reducing SGPR pressure.
|
||||
const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
|
||||
|
||||
// per-batch seqlen: prefer seqlen_ptr if available, else diff seqstart
|
||||
const index_t sq = amd_wave_read_first_lane(
|
||||
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[ibatch]
|
||||
: (kargs.seqstart_q_ptr[ibatch + 1] -
|
||||
kargs.seqstart_q_ptr[ibatch]));
|
||||
const index_t sk = amd_wave_read_first_lane(
|
||||
kargs.seqlen_k_ptr ? kargs.seqlen_k_ptr[ibatch]
|
||||
: (kargs.seqstart_k_ptr[ibatch + 1] -
|
||||
kargs.seqstart_k_ptr[ibatch]));
|
||||
const index_t nc = integer_divide_ceil(sk, FmhaPipeline::kN0);
|
||||
const index_t hw = nc * sq; // workload per (batch, head) pair
|
||||
if(hw == 0)
|
||||
continue;
|
||||
const index_t nsplits_b =
|
||||
amd_wave_read_first_lane(kargs.nsplits_ptr[ibatch]);
|
||||
|
||||
// first head whose interval overlaps [w_lo, w_hi)
|
||||
const index_t head_start = max(static_cast<index_t>((w_lo - pb) / hw), 0);
|
||||
|
||||
for(index_t head_idx = head_start; head_idx < kargs.nhead_q; ++head_idx)
|
||||
while(head_start < kargs.nhead_q)
|
||||
{
|
||||
const index_t w_head = pb + head_idx * hw;
|
||||
if(w_head >= w_hi)
|
||||
return; // remaining heads are past the interval
|
||||
|
||||
// wc_start: workload offset of this CU's start relative to head start.
|
||||
// Used for both isplit and c_start.
|
||||
const index_t wc_start = max(static_cast<index_t>(w_lo - w_head), 0);
|
||||
// isplit = rank of this CU among all CUs touching this head
|
||||
const index_t isplit = integer_divide_ceil(wc_start, target_w);
|
||||
const index_t c_start =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
|
||||
const index_t c_end = integer_divide_ceil(min(hw, w_hi - w_head), sq);
|
||||
|
||||
for(index_t chunk_idx = c_start; chunk_idx < c_end; ++chunk_idx)
|
||||
run_(kargs, dim3(chunk_idx, head_idx, ibatch), isplit, nsplits_b);
|
||||
while(c_start_0 < amd_wave_read_first_lane(bs->nc) && w_chunk < w_hi)
|
||||
{
|
||||
run_(kargs,
|
||||
dim3(tile_n_interleave(c_start_0,
|
||||
amd_wave_read_first_lane(bs->nc)),
|
||||
head_start,
|
||||
ibatch),
|
||||
isplit,
|
||||
amd_wave_read_first_lane(bs->nsplits));
|
||||
w_chunk += amd_wave_read_first_lane(bs->sq);
|
||||
++c_start_0;
|
||||
}
|
||||
if(w_chunk >= w_hi)
|
||||
return;
|
||||
// w_chunk is now at the start of the next head
|
||||
c_start_0 = 0;
|
||||
isplit = 0;
|
||||
++head_start;
|
||||
}
|
||||
}
|
||||
head_start = 0;
|
||||
++ibatch;
|
||||
} while(w_chunk < w_hi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user