mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:12:19 +00:00
@@ -196,7 +196,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -282,8 +283,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(!__builtin_isinf_sign(sink_v) || __builtin_isinf_sign(sink_v) > 0)
|
||||
{
|
||||
set_tile(m, sink_v);
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -887,7 +896,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -913,7 +923,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -120,8 +120,8 @@ const ck_tile::stream_config stream_config{
|
||||
1, // rotating_count_
|
||||
};
|
||||
|
||||
#define COMMON_ARGS \
|
||||
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
|
||||
#define COMMON_ARGS \
|
||||
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \
|
||||
stream_config
|
||||
|
||||
auto EnableTestIf(bool condition)
|
||||
@@ -255,6 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
0,
|
||||
1, //init_sink
|
||||
stream_config);
|
||||
ASSERT_EQ(result, fwd_result::invalid_args);
|
||||
}
|
||||
@@ -299,6 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
0,
|
||||
1, //init_sink
|
||||
stream_config);
|
||||
ASSERT_EQ(result, fwd_result::invalid_args);
|
||||
}
|
||||
@@ -342,6 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
0,
|
||||
1, //init_sink
|
||||
stream_config);
|
||||
ASSERT_EQ(result, fwd_result::invalid_args);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user