Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention

This commit is contained in:
Juuso Korhonen
2025-10-13 10:21:27 +00:00

View File

@@ -59,9 +59,7 @@ struct FmhaFwdV3Kernel
const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
void* o_ptr;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_blks;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
@@ -89,7 +87,7 @@ struct FmhaFwdV3Kernel
};
struct UnifiedAttentionVarlenKargs
struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs
{
const int32_t* block_tables_ptr;
const int32_t* seq_lens_ptr; // seq len in each batch
@@ -98,20 +96,15 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_seqs; // number of batches for q
};
struct Kargs {
UnifiedAttentionCommonKargs unifiedAttentionCommonKargs;
UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs;
};
// using Kargs = FmhaFwdGroupModeKargs;
using Kargs = UnifiedAttentionVarlenKargs;
CK_TILE_HOST static constexpr Kargs MakeKargs(
const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_blks,
ck_tile::index_t num_head_q,
const ck_tile::index_t num_queries_per_kv,
float scale_s,
@@ -135,15 +128,14 @@ struct FmhaFwdV3Kernel
const int32_t* block_tables_ptr,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs,
ck_tile::index_t num_seqs
)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
hdim_q,
hdim_v,
num_blks,
num_head_q,
num_queries_per_kv,
static_cast<float>(scale_s * ck_tile::log2e_v<>),
@@ -221,10 +213,10 @@ struct FmhaFwdV3Kernel
RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs)
{
using namespace ck_tile;
constexpr index_t NUM_XCDS = 8;
const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks *
(kargs.unifiedAttentionCommonKargs.num_head_q);
const index_t GRID_MN = kargs.total_num_q_blocks *
(kargs.num_head_q);
// Number of pids per XCD in the new arrangement
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
@@ -259,7 +251,7 @@ struct FmhaFwdV3Kernel
{
using namespace ck_tile;
ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks;
ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks;
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
// FmhaPipeline::kN1);
@@ -282,9 +274,12 @@ struct FmhaFwdV3Kernel
__shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t pid = blockIdx.x;
index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv;
const index_t BLOCK_M = BLOCK_Q * kargs.unifiedAttentionCommonKargs.num_queries_per_kv;
const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv;
// for simplicity, batch stride we just modify the pointer
const index_t num_head_q = kargs.num_head_q;
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
const index_t num_head_k = num_head_q / num_queries_per_kv;
pid = RemapTileIndices(pid, kargs);
@@ -297,15 +292,15 @@ struct FmhaFwdV3Kernel
// one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token
const index_t seq_idx = find_seq_idx(
kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true
kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true
); // which batch
const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]);
const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]);
const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]);
const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
@@ -315,7 +310,7 @@ struct FmhaFwdV3Kernel
}
const index_t query_pos = q_block_local_idx * BLOCK_Q;
const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; // should be cu_seqlens_q rather
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = seq_len - cur_batch_query_len;
@@ -326,10 +321,6 @@ struct FmhaFwdV3Kernel
+ 1
);
// for simplicity, batch stride we just modify the pointer
index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q;
index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv;
// Q/K/V DRAM and DRAM window
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start
@@ -371,6 +362,7 @@ struct FmhaFwdV3Kernel
make_tuple(sequence<0>{}, sequence<1>{})
);
// TODO are we padding the tensor view or the block here?
return q_dram_merged;
}();
@@ -385,7 +377,7 @@ struct FmhaFwdV3Kernel
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(num_b, BLOCK_SIZE, num_head_k, HEAD_SIZE),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});