diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index f83209a2c4..f9a7f1476e 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,7 +58,7 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - // BLOCK_Q BLOCK_SIZE HEAD_SIZE + // BLOCK_M BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence<128, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; @@ -109,6 +109,7 @@ struct unified_attention_kernel_traits template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index d0704626e9..914619dc5a 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -47,7 +47,7 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t BLOCK_Q = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // 
BLOCK size for K seqlen diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3dee9b4ad8..a8734ac0ab 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -268,7 +268,8 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; - static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; + static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; + static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; @@ -302,12 +303,12 @@ struct UnifiedAttentionPipeline } }(); - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize(index_t num_queries_per_kv) + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { // create another LDS buffer for p - return ck_tile::max(BLOCK_Q * num_queries_per_kv * HEAD_SIZE_PADDED * sizeof(PDataType), + return ck_tile::max(BLOCK_M * HEAD_SIZE_PADDED * sizeof(PDataType), Policy::template GetSmemSize() + - BLOCK_Q * num_queries_per_kv * BLOCK_SIZE * sizeof(PDataType)); + BLOCK_M * BLOCK_SIZE * sizeof(PDataType)); } // for debug only @@ -394,7 +395,7 @@ struct UnifiedAttentionPipeline void* smem_ptr) const { using namespace ck_tile; - + index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; static_assert( std::is_same_v> && @@ -409,7 +410,7 @@ struct UnifiedAttentionPipeline BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize(num_queries_per_kv)); + /* BLOCK_Q is a runtime local and cannot appear in static_assert; bound with compile-time BLOCK_M (>= BLOCK_Q when num_queries_per_kv >= 1) */ static_assert(sizeof(SaccDataType) * BLOCK_M * BLOCK_SIZE 
<= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc());