From 6981f148eef1e925d18fa4ac12d6073f31a23706 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 8 May 2026 06:27:19 -0400 Subject: [PATCH] Fix potential bug in kernel host interface BlockSize() --- .../hstu_attention_batched_forward_dispatch.hpp | 2 +- ...tu_attention_batched_forward_splitkv_dispatch.hpp | 6 +++--- .../18_hstu_attention/hstu_attention_fwd_kernel.hpp | 12 +++++++++++- .../hstu_attention_fwd_splitkv_kernel.hpp | 12 +++++++++++- .../hstu_attention_group_forward_dispatch.hpp | 2 +- ...hstu_attention_group_forward_splitkv_dispatch.hpp | 6 +++--- .../hstu_attention_jagged_forward_dispatch.hpp | 2 +- ...stu_attention_jagged_forward_splitkv_dispatch.hpp | 6 +++--- 8 files changed, 34 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 814c3af0d1..32d71a005a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -173,7 +173,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); dim3 kGridSize = HstuKernel::GridSize( param.num_batch, param.num_head, param.seqlen_q, param.hdim_v, has_minfull_attn_seqlen); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp index fb7b005aaf..e30d25885e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp @@ -236,7 +236,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v, param.num_splits, has_minfull_attn_seqlen); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( @@ -260,8 +260,8 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v); }(); - dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 4bcf37d163..ef689b5666 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -608,7 +608,17 @@ struct HstuAttentionFwdKernel } } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() + { + if(is_wave32()) + { + // it looks get_warp_size() always return 64 when called from host, so + // halfing is needed to get actual BlockSize + return dim3(kBlockSize / get_warp_size() * 32); + } + else + return dim3(kBlockSize); + } CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index 58111827c0..7f702e36e8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -625,7 +625,17 @@ struct HstuAttentionFwdSplitKVKernel return ck_tile::make_tuple(my_seqlen_k_start, my_seqlen_k_end); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() + { + if(is_wave32()) + { + // it looks get_warp_size() always return 64 when called from host, so + // halfing is needed to get actual BlockSize + return dim3(kBlockSize / get_warp_size() * 32); + } + else + return dim3(kBlockSize); + } CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp index 9e8df40392..4f58a8ec38 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp @@ -159,7 +159,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp index cd0591cbfd..6378d22c11 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp @@ -221,7 +221,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch dim3 kGridSize = HstuKernel::GridSize( param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( @@ -244,8 +244,8 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v); }(); - dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 052eb8e19e..a03ef80985 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -162,7 +162,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch param.max_seqlen_q, param.hdim_v, has_minfull_attn_seqlen); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp index 5f6444d31a..b72b6d7d85 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp @@ -224,7 +224,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v, param.num_splits, has_minfull_attn_seqlen); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel( @@ -247,8 +247,8 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch param.hdim_v); }(); - dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); - constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); + dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; (void)ck_tile::launch_kernel(