[CK_TILE][FMHA] Enable gpt-oss sink (#3490)

* Enable gpt-oss sink

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add gptoss sink test

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update CHANGELOG.md

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix test args error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update test_fmha_fwd.cpp

* update sink test

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Revert "update sink test"

This reverts commit 970b4f1686.

* update sink test

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update valid sink_v in splitkv pipeline

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update example_fmha_fwd.cpp

* fix lse error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix clang-format error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix aiter scale error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update block_fmha_pipeline_qr_ks_vs.hpp

* divide sink_value by scale_s

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update fmha_fwd_runner.hpp

* update sink_value with bias

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Fix typo in dropout parameter in fmha_batch_prefill_kernel

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update example_fmha_fwd.cpp

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* optimize and simplify some code

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix splitkv error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update sink reference

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update fmha_fwd_runner.hpp

* Update smoke_test_fwd_sink.sh

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
Linjun-AMD
2026-01-14 21:32:06 +08:00
committed by GitHub
parent 693ff3bbb3
commit 717ed0b59f
17 changed files with 487 additions and 110 deletions

View File

@@ -101,6 +101,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -346,12 +347,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
seqlen_q,
-1,
hdim_q,
@@ -491,12 +494,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
-1, // seqlen will be updated by another pointer
-1, //
hdim_q,
@@ -701,7 +706,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
const index_t seqlen_k = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
@@ -1226,7 +1234,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
dropout,
sink_value);
}
else
{
@@ -1248,7 +1257,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
dropout,
sink_value);
}
}();

View File

@@ -89,6 +89,7 @@ struct FmhaFwdKernel
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -343,12 +344,14 @@ struct FmhaFwdKernel
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
seqlen_q,
seqlen_k,
hdim_q,
@@ -490,7 +493,8 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -539,7 +543,8 @@ struct FmhaFwdKernel
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
cu_seqlen_k_ptr,
sink_ptr);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -591,7 +596,8 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -640,7 +646,8 @@ struct FmhaFwdKernel
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
cu_seqlen_k_ptr,
sink_ptr);
}
template <bool Cond = kIsGroupMode>
@@ -688,12 +695,14 @@ struct FmhaFwdKernel
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
-1, // seqlen will be updated by another pointer
-1, //
hdim_q,
@@ -833,7 +842,8 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -878,7 +888,8 @@ struct FmhaFwdKernel
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
cu_seqlen_k_ptr,
sink_ptr);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -926,7 +937,8 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -971,7 +983,8 @@ struct FmhaFwdKernel
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
cu_seqlen_k_ptr,
sink_ptr);
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1093,10 +1106,8 @@ struct FmhaFwdKernel
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
@@ -1107,6 +1118,10 @@ struct FmhaFwdKernel
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
if constexpr(kIsGroupMode)
{
@@ -1525,7 +1540,6 @@ struct FmhaFwdKernel
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
@@ -1566,7 +1580,8 @@ struct FmhaFwdKernel
variant_params,
block_indices,
smem_ptr,
dropout);
dropout,
sink_value);
}
else
{
@@ -1583,7 +1598,8 @@ struct FmhaFwdKernel
variant_params,
block_indices,
smem_ptr,
dropout);
dropout,
sink_value);
}
}();
@@ -1623,6 +1639,10 @@ struct FmhaFwdKernel
constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
@@ -2275,6 +2295,7 @@ struct FmhaFwdKernel
mask,
position_encoding,
kargs.scale_s,
sink_value,
smem_ptrk0,
smem_ptrk1,
smem_ptrv0,
@@ -2291,7 +2312,8 @@ struct FmhaFwdKernel
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
smem_ptr,
sink_value);
}
}();

View File

@@ -123,6 +123,7 @@ struct FmhaFwdPagedKVKernel
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -328,12 +329,14 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
seqlen_q,
seqlen_k,
hdim_q,
@@ -457,7 +460,8 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
const void* sink_ptr = nullptr)
{
return MakeKargsImpl(q_ptr,
k_ptr,
@@ -500,7 +504,8 @@ struct FmhaFwdPagedKVKernel
window_size_left,
window_size_right,
sink_size,
mask_type);
mask_type,
sink_ptr);
}
template <bool Cond = kIsGroupMode>
@@ -543,12 +548,14 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
ck_tile::index_t min_seqlen_q,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
-1, // seqlen will be updated by another pointer
-1, //
hdim_q,
@@ -669,7 +676,8 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
ck_tile::index_t min_seqlen_q,
const void* sink_ptr = nullptr)
{
return MakeKargsImpl(q_ptr,
k_ptr,
@@ -709,7 +717,8 @@ struct FmhaFwdPagedKVKernel
window_size_right,
sink_size,
mask_type,
min_seqlen_q);
min_seqlen_q,
sink_ptr);
}
CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
@@ -898,7 +907,6 @@ struct FmhaFwdPagedKVKernel
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
@@ -909,6 +917,10 @@ struct FmhaFwdPagedKVKernel
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
index_t kv_l2p_offset = 0;
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
if constexpr(kIsGroupMode)
{
@@ -1350,7 +1362,8 @@ struct FmhaFwdPagedKVKernel
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_value);
}
else
{
@@ -1368,7 +1381,8 @@ struct FmhaFwdPagedKVKernel
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_value);
}
}();

View File

@@ -124,6 +124,7 @@ struct FmhaFwdSplitKVKernel
const void* v_ptr;
void* lse_acc_ptr;
void* o_acc_ptr;
const void* sink_ptr;
ck_tile::index_t batch;
@@ -327,13 +328,15 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
lse_acc_ptr,
o_acc_ptr,
sink_ptr,
batch,
seqlen_q,
seqlen_k,
@@ -455,13 +458,15 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
lse_acc_ptr,
o_acc_ptr,
sink_ptr,
batch,
-1, // seqlen_q will be updated by another pointer
-1, // seqlen_k will be updated by another pointer
@@ -530,7 +535,6 @@ struct FmhaFwdSplitKVKernel
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
return kargs;
}
@@ -615,6 +619,10 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_o_acc = 0;
index_t kv_l2p_offset =
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
if constexpr(kIsGroupMode)
{
@@ -698,7 +706,6 @@ struct FmhaFwdSplitKVKernel
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer
const index_t i_nhead_k =
(kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
@@ -1083,7 +1090,8 @@ struct FmhaFwdSplitKVKernel
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_value);
}
else
{
@@ -1103,7 +1111,8 @@ struct FmhaFwdSplitKVKernel
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_value);
}
}();

View File

@@ -191,6 +191,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
static constexpr auto LOG2E = log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
@@ -297,7 +298,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -383,8 +385,24 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(__builtin_isinf_sign(sink_v) >= 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * LOG2E * scale_s);
else
set_tile(m, sink_v * LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
@@ -403,7 +421,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -1049,7 +1074,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout) const
DropoutType& dropout,
float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -1077,7 +1103,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
stride_v,
page_stride_k,
page_stride_v,
dropout);
dropout,
sink_v);
}
};

View File

@@ -163,7 +163,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
void* smem_ptr,
const float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -227,8 +228,24 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(__builtin_isinf_sign(sink_v) >= 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
@@ -258,7 +275,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -788,7 +812,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
void* smem_ptr,
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -812,7 +837,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_v);
}
};

View File

@@ -164,7 +164,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
void* smem_ptr,
float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -254,8 +255,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
{
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
@@ -285,7 +294,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
{
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
}
if(get_thread_local_1d_id() < kM0)
{
@@ -299,7 +315,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
return o_acc;
}
}
if(i_split > 0)
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
{
set_tile(m, SMPLComputeDataType{sink_v});
set_tile(l, SMPLComputeDataType{1.0f});
}
}
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
@@ -879,7 +904,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
void* smem_ptr,
float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -905,7 +931,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_v);
}
};

View File

@@ -163,7 +163,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
void* smem_ptr,
float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -227,8 +228,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
@@ -260,7 +277,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
{
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
@@ -272,6 +296,29 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
if(i_split > 0)
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
}
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
@@ -797,7 +844,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
void* smem_ptr,
float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -823,7 +871,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
smem_ptr,
sink_v);
}
};

View File

@@ -166,7 +166,8 @@ struct BlockFmhaPipelineQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -230,9 +231,24 @@ struct BlockFmhaPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(__builtin_isinf_sign(sink_v) >= 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
set_tile(m, sink_v * scale_s * C_LOG2E);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
@@ -265,7 +281,14 @@ struct BlockFmhaPipelineQRKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -798,7 +821,8 @@ struct BlockFmhaPipelineQRKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -821,7 +845,8 @@ struct BlockFmhaPipelineQRKSVS
variant_params,
block_indices,
smem_ptr,
dropout);
dropout,
sink_v);
}
};

View File

@@ -87,6 +87,7 @@ struct BlockFmhaPipelineQRKSVSAsync
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
static constexpr auto LOG2E = log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
@@ -188,7 +189,8 @@ struct BlockFmhaPipelineQRKSVSAsync
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -274,9 +276,24 @@ struct BlockFmhaPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(__builtin_isinf_sign(sink_v) >= 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
set_tile(m, sink_v * scale_s * LOG2E);
else
set_tile(m, sink_v * LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
@@ -309,7 +326,14 @@ struct BlockFmhaPipelineQRKSVSAsync
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -475,17 +499,10 @@ struct BlockFmhaPipelineQRKSVSAsync
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
}
else
{
@@ -880,7 +897,8 @@ struct BlockFmhaPipelineQRKSVSAsync
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -903,7 +921,8 @@ struct BlockFmhaPipelineQRKSVSAsync
variant_params,
block_indices,
smem_ptr,
dropout);
dropout,
sink_v);
}
};

View File

@@ -148,7 +148,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
void* smem_ptr,
float sink_v) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -193,8 +194,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(__builtin_isinf_sign(sink_v) >= 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
@@ -212,7 +229,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_acc_dram_window_tmp, lse_acc);
}
@@ -649,6 +673,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
float sink_v,
void* __restrict__ smem_ptrk0,
void* __restrict__ smem_ptrk1,
void* __restrict__ smem_ptrv0,
@@ -698,8 +723,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
if(__builtin_isinf_sign(sink_v) >= 0)
{
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
set_tile(m, sink_v * C_LOG2E * scale_s);
else
set_tile(m, sink_v * C_LOG2E);
#else
set_tile(m, sink_v);
#endif
set_tile(l, SMPLComputeDataType{1.0f});
}
else
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
@@ -717,7 +758,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(__builtin_isinf_sign(sink_v) >= 0)
{
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
}
else
{
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
}
store_tile(lse_acc_dram_window_tmp, lse_acc);
}