mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
@@ -196,7 +196,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -282,8 +283,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0)
|
||||
{
|
||||
set_tile(m, sink_v);
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -887,7 +896,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -913,7 +923,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user