mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fixed impl
This commit is contained in:
@@ -58,8 +58,13 @@ struct unified_attention_kernel_traits
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// TODO please fix this
|
||||
static constexpr index_t BLOCK_M = 256;
|
||||
static constexpr index_t num_head_per_kv = 4;
|
||||
static constexpr index_t BLOCK_Q = BLOCK_M / num_head_per_kv;
|
||||
|
||||
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<256, 64, 128, 128>;
|
||||
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, 32, 128>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ struct UnifiedAttentionKernel
|
||||
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
|
||||
ck_tile::index_t total_num_q_blocks)
|
||||
{
|
||||
return dim3(num_kv_heads * total_num_q_blocks, 0, 0);
|
||||
return dim3(num_kv_heads * total_num_q_blocks);
|
||||
}
|
||||
|
||||
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
|
||||
|
||||
Reference in New Issue
Block a user