Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-15 10:37:44 +00:00)
Move the lambda for dividing by max_seqlen from kernel to pipeline
This commit is contained in:
@@ -35,8 +35,6 @@ struct HstuAttentionFwdKernel
|
||||
using QKVDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::QKVDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::BiasDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
|
||||
using GemmAccDataType =
|
||||
ck_tile::remove_cvref_t<typename HstuAttentionPipeline::GemmAccDataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::VLayout>;
|
||||
|
||||
@@ -742,6 +740,7 @@ struct HstuAttentionFwdKernel
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.max_seqlen,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}();
|
||||
@@ -766,13 +765,6 @@ struct HstuAttentionFwdKernel
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{}, number<HstuAttentionPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x) {
|
||||
x = x * type_convert<GemmAccDataType>(
|
||||
__builtin_amdgcn_rcpf(static_cast<float>(kargs.max_seqlen)));
|
||||
},
|
||||
o_acc_tile);
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -133,6 +133,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
HstuMask mask,
|
||||
float scale_s,
|
||||
index_t max_seqlen,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -463,6 +464,13 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
} while(++i_loop < num_loops);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x) {
|
||||
x = x * type_convert<GemmAccDataType>(
|
||||
__builtin_amdgcn_rcpf(static_cast<float>(max_seqlen)));
|
||||
},
|
||||
o_acc);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
@@ -479,6 +487,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
float scale_s,
|
||||
int max_seqlen,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -495,6 +504,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
max_seqlen,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user