mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
pipeline api
This commit is contained in:
@@ -96,7 +96,6 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
};
|
||||
|
||||
|
||||
using Kargs = UnifiedAttentionVarlenKargs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(
|
||||
@@ -332,6 +331,8 @@ struct FmhaFwdV3Kernel
|
||||
index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start
|
||||
index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start
|
||||
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
|
||||
index_t block_table_offset = seq_idx * kargs.block_table_stride;
|
||||
|
||||
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
|
||||
@@ -445,10 +446,20 @@ struct FmhaFwdV3Kernel
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0});
|
||||
|
||||
// Create mask for causal attention
|
||||
auto mask = [&]() {
|
||||
return make_casual_mask(query_pos, BLOCK_Q, max_seq_prefix_len, BLOCK_SIZE);
|
||||
}();
|
||||
|
||||
// Define LSE dram window (or use a dummy if not needed by pipeline)
|
||||
auto lse_dram_window = make_dummy_tile_window();
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
block_tables_ptr,
|
||||
block_table_offset,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
|
||||
@@ -391,6 +391,8 @@ struct UnifiedAttentionPipeline
|
||||
[[maybe_unused]] const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
[[maybe_unused]] const VElementFunction& v_element_func,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
@@ -402,6 +404,8 @@ struct UnifiedAttentionPipeline
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
index_t block_idx_prev = 0;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
@@ -1231,6 +1235,8 @@ struct UnifiedAttentionPipeline
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
@@ -1244,6 +1250,8 @@ struct UnifiedAttentionPipeline
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
block_tables_ptr,
|
||||
block_table_offset,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
Reference in New Issue
Block a user