Adjust the code that calculates i_m0 in the kernel

This commit is contained in:
Qianfeng Zhang
2025-07-23 13:23:11 +00:00
parent f49fe28ca2
commit cf012c23fc

View File

@@ -651,23 +651,31 @@ struct HstuAttentionFwdKernel
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
index_t seqlen_in_first_split = kargs.seqlen;
bool is_tile_in_first_split = true;
index_t i_m0;
if constexpr(kHasLocalMask)
{
if(kargs.min_full_attn_seqlen > 0)
{
seqlen_in_first_split = kargs.seqlen - kargs.min_full_attn_seqlen - num_target;
};
index_t num_tile_in_first_split =
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
index_t num_tile_in_first_split =
ck_tile::integer_divide_ceil(seqlen_in_first_split, HstuAttentionPipeline::kM0);
bool is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
index_t i_m0 = is_tile_in_first_split
i_m0 = is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
}
else
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
}
else
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);