Move the lambda for dividing by max_seqlen from kernel to pipeline

This commit is contained in:
Qianfeng Zhang
2025-05-18 07:56:34 +00:00
parent 0771390a28
commit 58e45ec53a
2 changed files with 11 additions and 9 deletions

View File

@@ -35,8 +35,6 @@ struct HstuAttentionFwdKernel
using QKVDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::QKVDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::BiasDataType>;
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
using GemmAccDataType =
ck_tile::remove_cvref_t<typename HstuAttentionPipeline::GemmAccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::VLayout>;
@@ -742,6 +740,7 @@ struct HstuAttentionFwdKernel
bias_dram_window,
mask,
kargs.scale_s,
kargs.max_seqlen,
smem_ptr,
dropout);
}();
@@ -766,13 +765,6 @@ struct HstuAttentionFwdKernel
make_tuple(number<HstuAttentionPipeline::kM0>{}, number<HstuAttentionPipeline::kN1>{}),
{i_m0, i_n1});
tile_elementwise_inout(
[&](auto& x) {
x = x * type_convert<GemmAccDataType>(
__builtin_amdgcn_rcpf(static_cast<float>(kargs.max_seqlen)));
},
o_acc_tile);
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}
};

View File

@@ -133,6 +133,7 @@ struct HstuAttentionFwdPipelineQRKSVS
const OAccElementFunction& o_acc_element_func,
HstuMask mask,
float scale_s,
index_t max_seqlen,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -463,6 +464,13 @@ struct HstuAttentionFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
} while(++i_loop < num_loops);
tile_elementwise_inout(
[&](auto& x) {
x = x * type_convert<GemmAccDataType>(
__builtin_amdgcn_rcpf(static_cast<float>(max_seqlen)));
},
o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
@@ -479,6 +487,7 @@ struct HstuAttentionFwdPipelineQRKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
HstuMask mask,
float scale_s,
int max_seqlen,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -495,6 +504,7 @@ struct HstuAttentionFwdPipelineQRKSVS
identity{},
mask,
scale_s,
max_seqlen,
smem_ptr,
dropout);
}