Fixed impl

This commit is contained in:
Tianxing Wu
2025-11-13 14:02:24 +00:00
parent f4392ddaaf
commit db7224e067
2 changed files with 7 additions and 2 deletions

View File

@@ -58,8 +58,13 @@ struct unified_attention_kernel_traits
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
// TODO please fix this
static constexpr index_t BLOCK_M = 256;
static constexpr index_t num_head_per_kv = 4;
static constexpr index_t BLOCK_Q = BLOCK_M / num_head_per_kv;
// BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE
using unified_attention_block_tile = sequence<256, 64, 128, 128>;
using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, 32, 128>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
using unified_attention_block_warps = sequence<8, 1, 1>;

View File

@@ -167,7 +167,7 @@ struct UnifiedAttentionKernel
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
ck_tile::index_t total_num_q_blocks)
{
return dim3(num_kv_heads * total_num_q_blocks, 0, 0);
return dim3(num_kv_heads * total_num_q_blocks);
}
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,