mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention
This commit is contained in:
@@ -59,9 +59,7 @@ struct FmhaFwdV3Kernel
|
||||
const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_blks;
|
||||
ck_tile::index_t num_head_q;
|
||||
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
@@ -89,7 +87,7 @@ struct FmhaFwdV3Kernel
|
||||
};
|
||||
|
||||
|
||||
struct UnifiedAttentionVarlenKargs
|
||||
struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs
|
||||
{
|
||||
const int32_t* block_tables_ptr;
|
||||
const int32_t* seq_lens_ptr; // seq len in each batch
|
||||
@@ -98,20 +96,15 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
};
|
||||
|
||||
struct Kargs {
|
||||
UnifiedAttentionCommonKargs unifiedAttentionCommonKargs;
|
||||
UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs;
|
||||
};
|
||||
|
||||
// using Kargs = FmhaFwdGroupModeKargs;
|
||||
using Kargs = UnifiedAttentionVarlenKargs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(
|
||||
const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_blks,
|
||||
ck_tile::index_t num_head_q,
|
||||
const ck_tile::index_t num_queries_per_kv,
|
||||
float scale_s,
|
||||
@@ -135,15 +128,14 @@ struct FmhaFwdV3Kernel
|
||||
const int32_t* block_tables_ptr,
|
||||
const int32_t* seq_lens_ptr,
|
||||
const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t num_seqs
|
||||
)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_blks,
|
||||
num_head_q,
|
||||
num_queries_per_kv,
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
@@ -221,10 +213,10 @@ struct FmhaFwdV3Kernel
|
||||
RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
|
||||
constexpr index_t NUM_XCDS = 8;
|
||||
const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks *
|
||||
(kargs.unifiedAttentionCommonKargs.num_head_q);
|
||||
const index_t GRID_MN = kargs.total_num_q_blocks *
|
||||
(kargs.num_head_q);
|
||||
|
||||
// Number of pids per XCD in the new arrangement
|
||||
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
|
||||
@@ -259,7 +251,7 @@ struct FmhaFwdV3Kernel
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks;
|
||||
ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// FmhaPipeline::kN1);
|
||||
|
||||
@@ -282,9 +274,12 @@ struct FmhaFwdV3Kernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
ck_tile::index_t pid = blockIdx.x;
|
||||
index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv;
|
||||
|
||||
const index_t BLOCK_M = BLOCK_Q * kargs.unifiedAttentionCommonKargs.num_queries_per_kv;
|
||||
const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv;
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const index_t num_head_q = kargs.num_head_q;
|
||||
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
|
||||
const index_t num_head_k = num_head_q / num_queries_per_kv;
|
||||
|
||||
pid = RemapTileIndices(pid, kargs);
|
||||
|
||||
@@ -297,15 +292,15 @@ struct FmhaFwdV3Kernel
|
||||
// one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token
|
||||
|
||||
const index_t seq_idx = find_seq_idx(
|
||||
kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true
|
||||
kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true
|
||||
); // which batch
|
||||
|
||||
const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
|
||||
|
||||
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
|
||||
const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]);
|
||||
const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
|
||||
|
||||
const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
|
||||
|
||||
@@ -315,7 +310,7 @@ struct FmhaFwdV3Kernel
|
||||
}
|
||||
|
||||
const index_t query_pos = q_block_local_idx * BLOCK_Q;
|
||||
const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; // should be cu_seqlens_q rather
|
||||
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = seq_len - cur_batch_query_len;
|
||||
|
||||
@@ -326,10 +321,6 @@ struct FmhaFwdV3Kernel
|
||||
+ 1
|
||||
);
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q;
|
||||
index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start
|
||||
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start
|
||||
@@ -371,6 +362,7 @@ struct FmhaFwdV3Kernel
|
||||
make_tuple(sequence<0>{}, sequence<1>{})
|
||||
);
|
||||
|
||||
// TODO are we padding the tensor view or the block here?
|
||||
return q_dram_merged;
|
||||
}();
|
||||
|
||||
@@ -385,7 +377,7 @@ struct FmhaFwdV3Kernel
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(num_b, BLOCK_SIZE, num_head_k, HEAD_SIZE),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
Reference in New Issue
Block a user