fix test args error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
This commit is contained in:
Linjun-AMD
2025-12-29 21:25:58 -06:00
parent 4d73213f37
commit 5ab683b02b
2 changed files with 21 additions and 7 deletions

View File

@@ -196,7 +196,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -282,8 +283,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0)
{
set_tile(m, sink_v);
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
@@ -887,7 +896,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -913,7 +923,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
page_idx,
stride_k,
stride_v,
dropout);
dropout,
sink_v);
}
};

View File

@@ -120,8 +120,8 @@ const ck_tile::stream_config stream_config{
1, // rotating_count_
};
#define COMMON_ARGS \
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
#define COMMON_ARGS \
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \
stream_config
auto EnableTestIf(bool condition)
@@ -255,6 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
1, //init_sink
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
@@ -299,6 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
1, //init_sink
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
@@ -342,6 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
1, //init_sink
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}