Rename i_m0 to i_m in the split-KV combine kernel and LSEorLSEaccDramBlockWindowTmp to LSEorLSEaccDramBlockWindow in the pipeline

This commit is contained in:
Qianfeng Zhang
2026-06-05 15:43:47 +00:00
parent 42a3bfbab7
commit f73341de37
2 changed files with 17 additions and 17 deletions

View File

@@ -329,11 +329,11 @@ struct HstuAttentionFwdSplitKVCombineKernel
}
}
index_t i_m0;
index_t i_m;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM);
i_m = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM);
if(kargs.seqlen_q <= i_m0)
if(kargs.seqlen_q <= i_m)
return;
// assume o_acc is in compact shape of [batch_size, seqlen, num_head, num_splits, hdim]
@@ -364,7 +364,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
make_tile_window(o_acc_dram,
make_tuple(number<HstuAttentionPipeline::kM>{},
number<HstuAttentionPipeline::kOHeaddim>{}),
{i_m0, 0});
{i_m, 0});
auto o_acc_tile = [&]() {
if constexpr(kUseSoftmax)
@@ -402,7 +402,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
make_tile_window(lse_acc_dram,
make_tuple(number<HstuAttentionPipeline::kM>{},
number<HstuAttentionPipeline::kMaxSplits>{}),
{i_m0, 0});
{i_m, 0});
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths =
@@ -427,7 +427,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
lse_dram_naive, lse_dram_window_lengths, sequence<false>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m});
}
else
{
@@ -467,7 +467,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
make_tile_window(o_dram,
make_tuple(number<HstuAttentionPipeline::kM>{},
number<HstuAttentionPipeline::kOHeaddim>{}),
{i_m0, 0});
{i_m, 0});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}

View File

@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename QElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
@@ -134,7 +134,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
const LSEaccElementFunction& lse_or_lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
@@ -205,7 +205,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
{
auto lse_or_lse_acc =
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
@@ -606,7 +606,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// If the pipeline is called from the split-KV kernel, the window must not be null;
// if it is called from a non-split-KV kernel, the window is null when kStoreLSE is false.
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
{
// store lse or lse_acc
auto lse_or_lse_acc =
@@ -647,14 +647,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,