mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Fix potential bug in kernel host interface BlockSize()
This commit is contained in:
@@ -173,7 +173,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0);
|
||||
dim3 kGridSize = HstuKernel::GridSize(
|
||||
param.num_batch, param.num_head, param.seqlen_q, param.hdim_v, has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
|
||||
@@ -236,7 +236,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v,
|
||||
param.num_splits,
|
||||
has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
@@ -260,8 +260,8 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q);
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
|
||||
@@ -608,7 +608,17 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the kernel launch block dimensions: always dim3(kBlockSize).
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
// Returns the block dimensions (dim3) to use when launching this kernel from the host.
// When is_wave32() is true the thread count is rescaled to
// kBlockSize / get_warp_size() * 32; otherwise kBlockSize is returned unchanged.
// NOTE(review): this is declared constexpr, yet the dispatch call sites store the result
// in a plain (non-constexpr) dim3 -- presumably is_wave32() is a runtime query; confirm.
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
// It appears that get_warp_size() always returns 64 when called from the host, so
|
||||
// halving is needed to get the actual BlockSize on wave32 targets.
|
||||
return dim3(kBlockSize / get_warp_size() * 32);
|
||||
}
|
||||
else
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -625,7 +625,17 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
return ck_tile::make_tuple(my_seqlen_k_start, my_seqlen_k_end);
|
||||
}
|
||||
|
||||
// Returns the kernel launch block dimensions: always dim3(kBlockSize).
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
// Returns the block dimensions (dim3) to use when launching this kernel from the host.
// When is_wave32() is true the thread count is rescaled to
// kBlockSize / get_warp_size() * 32; otherwise kBlockSize is returned unchanged.
// NOTE(review): this is declared constexpr, yet the dispatch call sites store the result
// in a plain (non-constexpr) dim3 -- presumably is_wave32() is a runtime query; confirm.
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
// It appears that get_warp_size() always returns 64 when called from the host, so
|
||||
// halving is needed to get the actual BlockSize on wave32 targets.
|
||||
return dim3(kBlockSize / get_warp_size() * 32);
|
||||
}
|
||||
else
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -159,7 +159,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
|
||||
@@ -221,7 +221,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(
|
||||
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
@@ -244,8 +244,8 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
|
||||
@@ -162,7 +162,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.max_seqlen_q,
|
||||
param.hdim_v,
|
||||
has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
|
||||
@@ -224,7 +224,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v,
|
||||
param.num_splits,
|
||||
has_minfull_attn_seqlen);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
@@ -247,8 +247,8 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
|
||||
Reference in New Issue
Block a user