mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Various fixes
This commit is contained in:
@@ -184,7 +184,7 @@ struct UnifiedAttentionKernel
|
||||
while(left < right)
|
||||
{
|
||||
ck_tile::index_t mid = (left + right) / 2;
|
||||
ck_tile::index_t val = query_start_len_ptr[mid];
|
||||
ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]);
|
||||
ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
|
||||
|
||||
if(mid_val <= target_idx)
|
||||
@@ -206,7 +206,7 @@ struct UnifiedAttentionKernel
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t NUM_XCDS = 8;
|
||||
const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q);
|
||||
const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv);
|
||||
|
||||
// Number of pids per XCD in the new arrangement
|
||||
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
|
||||
@@ -245,10 +245,7 @@ struct UnifiedAttentionKernel
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// UnifiedAttentionPipeline::kN1);
|
||||
|
||||
const index_t i_tile_m = pid % total_num_q_blocks; // Query block index
|
||||
const index_t i_tile_n = pid / total_num_q_blocks; // Head index
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n);
|
||||
return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -277,7 +274,6 @@ struct UnifiedAttentionKernel
|
||||
// const index_t num_head_q = kargs.num_head_q;
|
||||
|
||||
// const index_t num_head_k = num_head_q / num_queries_per_kv;
|
||||
|
||||
pid = RemapTileIndices(pid, kargs);
|
||||
|
||||
// divide problem
|
||||
@@ -295,19 +291,15 @@ struct UnifiedAttentionKernel
|
||||
BLOCK_Q,
|
||||
true); // which batch
|
||||
|
||||
const index_t q_block_start_idx =
|
||||
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx;
|
||||
|
||||
const index_t q_block_local_idx =
|
||||
amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
|
||||
|
||||
const index_t cur_batch_in_all_start_index =
|
||||
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]);
|
||||
const index_t cur_batch_in_all_stop_index =
|
||||
amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]);
|
||||
const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx];
|
||||
const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1];
|
||||
|
||||
const index_t cur_batch_query_len =
|
||||
cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
|
||||
amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index);
|
||||
|
||||
// TODO check if we get the block size info from pipeline
|
||||
if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len)
|
||||
@@ -315,14 +307,14 @@ struct UnifiedAttentionKernel
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t query_pos = q_block_local_idx * BLOCK_Q;
|
||||
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
|
||||
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = seq_len - cur_batch_query_len;
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
|
||||
index_t _max_seq_prefix_len =
|
||||
(context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
|
||||
+ 1);
|
||||
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
|
||||
+ 1));
|
||||
|
||||
if(seq_len < _max_seq_prefix_len)
|
||||
{
|
||||
@@ -330,7 +322,7 @@ struct UnifiedAttentionKernel
|
||||
}
|
||||
|
||||
const auto max_seq_prefix_len = _max_seq_prefix_len;
|
||||
const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
||||
|
||||
// TODO sliding window
|
||||
const index_t num_blocks_start = 0;
|
||||
@@ -357,7 +349,7 @@ struct UnifiedAttentionKernel
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
|
||||
|
||||
index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q;
|
||||
index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q);
|
||||
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
@@ -367,20 +359,20 @@ struct UnifiedAttentionKernel
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
|
||||
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
|
||||
number<UnifiedAttentionPipeline::kAlignmentQ>{},
|
||||
number<2>{});
|
||||
number<1>{});
|
||||
|
||||
const auto q_dram_pad =
|
||||
pad_tensor_view( // align seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
|
||||
q_dram_base,
|
||||
// block sizes
|
||||
make_tuple(number<BLOCK_Q>{}, number<1>{}, number<HEAD_SIZE_PADDED>{}),
|
||||
make_tuple(number<BLOCK_Q>{}, 1, HEAD_SIZE_PADDED),
|
||||
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
|
||||
// HEAD_SIZE_PADDED)
|
||||
|
||||
const auto q_dram_merged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(number<HEAD_SIZE_PADDED>{})),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{})); // flattens the first two dims, head idx is the fastest
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
#define ENABLE_ASM_MARKER 1
|
||||
#if ENABLE_ASM_MARKER
|
||||
#define ASM_MARKER(marker) \
|
||||
|
||||
Reference in New Issue
Block a user