Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-07-01 04:07:56 +00:00)
fixing bugs
This commit is contained in:
@@ -460,8 +460,8 @@ struct UnifiedAttentionKernel
|
||||
cur_batch_query_len, // x (i.e. extend)
|
||||
seq_len, // y_total (x + y)
|
||||
cur_batch_query_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile
|
||||
);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
@@ -470,6 +470,7 @@ struct UnifiedAttentionKernel
|
||||
return UnifiedAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
num_queries_per_kv,
|
||||
kargs.block_tables_ptr,
|
||||
block_table_offset,
|
||||
mask,
|
||||
|
||||
@@ -278,8 +278,8 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDim;
|
||||
// static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
|
||||
// last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
|
||||
@@ -1208,6 +1208,7 @@ struct UnifiedAttentionPipeline
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
index_t num_queries_per_kv,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
FmhaMask mask,
|
||||
@@ -1216,12 +1217,29 @@ struct UnifiedAttentionPipeline
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
[[maybe_unused]] const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
[[maybe_unused]] const VElementFunction& v_element_func,
|
||||
index_t num_queries_per_kv,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
num_queries_per_kv,
|
||||
block_tables_ptr,
|
||||
block_table_offset,
|
||||
identity{},
|
||||
|
||||
Reference in New Issue
Block a user