fix test args error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
This commit is contained in:
Linjun-AMD
2025-12-29 21:25:58 -06:00
parent 4d73213f37
commit 5ab683b02b
2 changed files with 21 additions and 7 deletions

View File

@@ -196,7 +196,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -282,8 +283,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0)
{
set_tile(m, sink_v);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
@@ -887,7 +896,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -913,7 +923,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
page_idx,
stride_k,
stride_v,
dropout);
dropout,
sink_v);
}
};