mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
Some renaming in kernel and pipeline
This commit is contained in:
@@ -329,11 +329,11 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
}
|
||||
}
|
||||
|
||||
index_t i_m0;
|
||||
index_t i_m;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM);
|
||||
i_m = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM);
|
||||
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
if(kargs.seqlen_q <= i_m)
|
||||
return;
|
||||
|
||||
// assume o_acc is in compact shape of [batch_size, seqlen, num_head, num_splits, hdim]
|
||||
@@ -364,7 +364,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
make_tile_window(o_acc_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kM>{},
|
||||
number<HstuAttentionPipeline::kOHeaddim>{}),
|
||||
{i_m0, 0});
|
||||
{i_m, 0});
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kUseSoftmax)
|
||||
@@ -402,7 +402,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
make_tile_window(lse_acc_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kM>{},
|
||||
number<HstuAttentionPipeline::kMaxSplits>{}),
|
||||
{i_m0, 0});
|
||||
{i_m, 0});
|
||||
|
||||
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto lse_dram_window_lengths =
|
||||
@@ -427,7 +427,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
lse_dram_naive, lse_dram_window_lengths, sequence<false>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
|
||||
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -467,7 +467,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kM>{},
|
||||
number<HstuAttentionPipeline::kOHeaddim>{}),
|
||||
{i_m0, 0});
|
||||
{i_m, 0});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
@@ -134,7 +134,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
const LSEaccElementFunction& lse_or_lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
@@ -205,7 +205,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
|
||||
{
|
||||
auto lse_or_lse_acc =
|
||||
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
@@ -606,7 +606,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// if pipeline is called from splitkv_kernel, the window shall not be null;
|
||||
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
|
||||
{
|
||||
// store lse or lse_acc
|
||||
auto lse_or_lse_acc =
|
||||
@@ -647,14 +647,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
|
||||
Reference in New Issue
Block a user