Fix potential bug in kernel host interface BlockSize()

This commit is contained in:
Qianfeng Zhang
2026-05-08 06:27:19 -04:00
parent 250f325c3a
commit 6981f148ee
8 changed files with 34 additions and 14 deletions

View File

@@ -173,7 +173,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
dim3 kGridSize = HstuKernel::GridSize(
param.num_batch, param.num_head, param.seqlen_q, param.hdim_v, has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(

View File

@@ -236,7 +236,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v,
param.num_splits,
has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(
@@ -260,8 +260,8 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v);
}();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q);
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(

View File

@@ -608,7 +608,17 @@ struct HstuAttentionFwdKernel
}
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
// Host-side block dimensions to use when launching this kernel.
// When is_wave32() is true, the result is kBlockSize rescaled by
// 32 / get_warp_size(); otherwise it is kBlockSize unchanged.
// NOTE(review): the call sites touched by this change no longer store the
// result in a constexpr variable, presumably because is_wave32() is not a
// constant expression -- confirm whether constexpr is still meaningful here.
CK_TILE_HOST static constexpr auto BlockSize()
{
if(is_wave32())
{
// get_warp_size() appears to always return 64 when called from the host,
// so kBlockSize is rescaled (halved, given that value) to obtain the
// actual block size for the wave32 case.
return dim3(kBlockSize / get_warp_size() * 32);
}
else
return dim3(kBlockSize);
}
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{

View File

@@ -625,7 +625,17 @@ struct HstuAttentionFwdSplitKVKernel
return ck_tile::make_tuple(my_seqlen_k_start, my_seqlen_k_end);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
// Host-side block dimensions to use when launching this kernel.
// When is_wave32() is true, the result is kBlockSize rescaled by
// 32 / get_warp_size(); otherwise it is kBlockSize unchanged.
// NOTE(review): the call sites touched by this change no longer store the
// result in a constexpr variable, presumably because is_wave32() is not a
// constant expression -- confirm whether constexpr is still meaningful here.
CK_TILE_HOST static constexpr auto BlockSize()
{
if(is_wave32())
{
// get_warp_size() appears to always return 64 when called from the host,
// so kBlockSize is rescaled (halved, given that value) to obtain the
// actual block size for the wave32 case.
return dim3(kBlockSize / get_warp_size() * 32);
}
else
return dim3(kBlockSize);
}
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{

View File

@@ -159,7 +159,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
dim3 kGridSize =
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(

View File

@@ -221,7 +221,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
dim3 kGridSize = HstuKernel::GridSize(
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(
@@ -244,8 +244,8 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v);
}();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(

View File

@@ -162,7 +162,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
param.max_seqlen_q,
param.hdim_v,
has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(

View File

@@ -224,7 +224,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v,
param.num_splits,
has_minfull_attn_seqlen);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(
@@ -247,8 +247,8 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.hdim_v);
}();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(