From db7224e0671dbb3f83e063ec0ab92d5a88f3bf6a Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 13 Nov 2025 14:02:24 +0000 Subject: [PATCH] Fixed impl --- .../01_unified_attention/unified_attention_impl.hpp | 7 ++++++- .../unified_attention/kernel/unified_attention_kernel.hpp | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index dc3104e4f2..995ea8fdf8 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,8 +58,13 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; + // TODO please fix this + static constexpr index_t BLOCK_M = 256; + static constexpr index_t num_head_per_kv = 4; + static constexpr index_t BLOCK_Q = BLOCK_M / num_head_per_kv; + // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE - using unified_attention_block_tile = sequence<256, 64, 128, 128>; + using unified_attention_block_tile = sequence<BLOCK_M, BLOCK_Q, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ac7a06a961..1cf7698b61 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -167,7 +167,7 @@ struct UnifiedAttentionKernel CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, ck_tile::index_t total_num_q_blocks) { - return dim3(num_kv_heads * total_num_q_blocks, 0, 0); + return dim3(num_kv_heads * total_num_q_blocks); } // CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,