fixing bugs

This commit is contained in:
Juuso Korhonen
2025-10-23 11:40:48 +00:00
parent 3bcef59536
commit 5bf72d2bcb
2 changed files with 23 additions and 4 deletions

View File

@@ -460,8 +460,8 @@ struct UnifiedAttentionKernel
cur_batch_query_len, // x (i.e. extend)
seq_len, // y_total (x + y)
cur_batch_query_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile
);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();
@@ -470,6 +470,7 @@ struct UnifiedAttentionKernel
return UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
num_queries_per_kv,
kargs.block_tables_ptr,
block_table_offset,
mask,

View File

@@ -278,8 +278,8 @@ struct UnifiedAttentionPipeline
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDim;
// static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
@@ -1208,6 +1208,7 @@ struct UnifiedAttentionPipeline
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
index_t num_queries_per_kv,
const void* block_tables_ptr,
index_t block_table_offset,
FmhaMask mask,
@@ -1216,12 +1217,29 @@ struct UnifiedAttentionPipeline
{
using namespace ck_tile;
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
[[maybe_unused]] const VElementFunction& v_element_func,
index_t num_queries_per_kv,
const void* block_tables_ptr,
index_t block_table_offset,
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale_s,
void* smem_ptr) const
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
num_queries_per_kv,
block_tables_ptr,
block_table_offset,
identity{},