mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE][FMHA] Enable gpt-oss sink (#3490)
* Enable gptoss sink
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add gptoss sink test
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* update CHANGELOG.md
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix test args error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update test_fmha_fwd.cpp
* update sink test
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Revert "update sink test"
This reverts commit 970b4f1686.
* update sink test
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* update valid sink_v in splitkv pipeline
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
* Update example_fmha_fwd.cpp
* fix lse error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix clangformat error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix aiter scale error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update block_fmha_pipeline_qr_ks_vs.hpp
* div scale_s for sink_value
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update fmha_fwd_runner.hpp
* update sink_value with bias
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
* Fix typo in dropout parameter in fmha_batch_prefill_kernel
* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
* Update example_fmha_fwd.cpp
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* optimized some code
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix splitkv error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* update sink reference
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update fmha_fwd_runner.hpp
* Update smoke_test_fwd_sink.sh
---------
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -101,6 +101,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -346,12 +347,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
seqlen_q,
|
||||
-1,
|
||||
hdim_q,
|
||||
@@ -491,12 +494,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
@@ -701,7 +706,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
const index_t seqlen_k = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
@@ -1226,7 +1234,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1248,7 +1257,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -343,12 +344,14 @@ struct FmhaFwdKernel
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -490,7 +493,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -539,7 +543,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -591,7 +596,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -640,7 +646,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -688,12 +695,14 @@ struct FmhaFwdKernel
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
@@ -833,7 +842,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -878,7 +888,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -926,7 +937,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -971,7 +983,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
@@ -1093,10 +1106,8 @@ struct FmhaFwdKernel
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -1107,6 +1118,10 @@ struct FmhaFwdKernel
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -1525,7 +1540,6 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
@@ -1566,7 +1580,8 @@ struct FmhaFwdKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1583,7 +1598,8 @@ struct FmhaFwdKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1623,6 +1639,10 @@ struct FmhaFwdKernel
|
||||
constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
|
||||
const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
|
||||
@@ -2275,6 +2295,7 @@ struct FmhaFwdKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
sink_value,
|
||||
smem_ptrk0,
|
||||
smem_ptrk1,
|
||||
smem_ptrv0,
|
||||
@@ -2291,7 +2312,8 @@ struct FmhaFwdKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -123,6 +123,7 @@ struct FmhaFwdPagedKVKernel
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -328,12 +329,14 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -457,7 +460,8 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
@@ -500,7 +504,8 @@ struct FmhaFwdPagedKVKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type);
|
||||
mask_type,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -543,12 +548,14 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
@@ -669,7 +676,8 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
@@ -709,7 +717,8 @@ struct FmhaFwdPagedKVKernel
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q);
|
||||
min_seqlen_q,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
|
||||
@@ -898,7 +907,6 @@ struct FmhaFwdPagedKVKernel
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -909,6 +917,10 @@ struct FmhaFwdPagedKVKernel
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
index_t kv_l2p_offset = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -1350,7 +1362,8 @@ struct FmhaFwdPagedKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1368,7 +1381,8 @@ struct FmhaFwdPagedKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -124,6 +124,7 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* v_ptr;
|
||||
void* lse_acc_ptr;
|
||||
void* o_acc_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
|
||||
@@ -327,13 +328,15 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
sink_ptr,
|
||||
batch,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
@@ -455,13 +458,15 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
sink_ptr,
|
||||
batch,
|
||||
-1, // seqlen_q will be updated by another pointer
|
||||
-1, // seqlen_k will be updated by another pointer
|
||||
@@ -530,7 +535,6 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -615,6 +619,10 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
index_t kv_l2p_offset =
|
||||
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -698,7 +706,6 @@ struct FmhaFwdSplitKVKernel
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const index_t i_nhead_k =
|
||||
(kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
|
||||
@@ -1083,7 +1090,8 @@ struct FmhaFwdSplitKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1103,7 +1111,8 @@ struct FmhaFwdSplitKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -191,6 +191,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
static constexpr auto LOG2E = log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -297,7 +298,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -383,8 +385,24 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -403,7 +421,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -1049,7 +1074,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -1077,7 +1103,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -227,8 +228,24 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
@@ -258,7 +275,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -788,7 +812,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -812,7 +837,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -164,7 +164,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -254,8 +255,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
@@ -285,7 +294,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
@@ -299,7 +315,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
if(i_split > 0)
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
}
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -879,7 +904,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -905,7 +931,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -227,8 +228,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
@@ -260,7 +277,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
@@ -272,6 +296,29 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
if(i_split > 0)
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
}
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -797,7 +844,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -823,7 +871,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -166,7 +166,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -230,9 +231,24 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
set_tile(m, sink_v * scale_s * C_LOG2E);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
@@ -265,7 +281,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -798,7 +821,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -821,7 +845,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
static constexpr auto LOG2E = log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -188,7 +189,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -274,9 +276,24 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
set_tile(m, sink_v * scale_s * LOG2E);
|
||||
else
|
||||
set_tile(m, sink_v * LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
@@ -309,7 +326,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -475,17 +499,10 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -880,7 +897,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -903,7 +921,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -148,7 +148,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -193,8 +194,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
@@ -212,7 +229,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
@@ -649,6 +673,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
float sink_v,
|
||||
void* __restrict__ smem_ptrk0,
|
||||
void* __restrict__ smem_ptrk1,
|
||||
void* __restrict__ smem_ptrv0,
|
||||
@@ -698,8 +723,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
@@ -717,7 +758,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user