pipeline api

This commit is contained in:
Tianxing Wu
2025-10-14 09:58:27 +00:00
parent 6a7fa959b7
commit e1120fffb0
2 changed files with 20 additions and 1 deletions

View File

@@ -96,7 +96,6 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_seqs; // number of sequences (batch entries) in q
};
using Kargs = UnifiedAttentionVarlenKargs;
CK_TILE_HOST static constexpr Kargs MakeKargs(
@@ -332,6 +331,8 @@ struct FmhaFwdV3Kernel
index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start
index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
index_t block_table_offset = seq_idx * kargs.block_table_stride;
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
@@ -445,10 +446,20 @@ struct FmhaFwdV3Kernel
auto v_dram_window = make_tile_window(
v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0});
// Create mask for causal attention
auto mask = [&]() {
return make_casual_mask(query_pos, BLOCK_Q, max_seq_prefix_len, BLOCK_SIZE);
}();
// Define the LSE dram window; this call site always supplies a dummy window (no real LSE buffer is bound here)
auto lse_dram_window = make_dummy_tile_window();
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
block_tables_ptr,
block_table_offset,
lse_dram_window,
mask,
kargs.scale_s,

View File

@@ -391,6 +391,8 @@ struct UnifiedAttentionPipeline
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
[[maybe_unused]] const VElementFunction& v_element_func,
const void* block_tables_ptr,
index_t block_table_offset,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
@@ -402,6 +404,8 @@ struct UnifiedAttentionPipeline
{
using namespace ck_tile;
index_t block_idx_prev = 0;
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
@@ -1231,6 +1235,8 @@ struct UnifiedAttentionPipeline
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const void* block_tables_ptr,
index_t block_table_offset,
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
@@ -1244,6 +1250,8 @@ struct UnifiedAttentionPipeline
identity{},
v_dram_block_window_tmp,
identity{},
block_tables_ptr,
block_table_offset,
lse_dram_block_window_tmp,
identity{},
identity{},