mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE][FMHA] Enable gpt-oss sink (#3490)
* Enable gptoss sink
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add gptoss sink test
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* update CHANGELOG.md
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix test args error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update test_fmha_fwd.cpp
* update sink test
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Revert "update sink test"
This reverts commit 970b4f1686.
* update sink test
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* update valid sink_v in splitkv pipeline
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
* Update example_fmha_fwd.cpp
* fix lse error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix clangformat error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix aiter scale error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update block_fmha_pipeline_qr_ks_vs.hpp
* div scale_s for sink_value
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update fmha_fwd_runner.hpp
* update sink_value with bias
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
* Fix typo in dropout parameter in fmha_batch_prefill_kernel
* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
* Update example_fmha_fwd.cpp
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* optimized some code
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* fix splitkv error
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* update sink reference
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
* Update fmha_fwd_runner.hpp
* Update smoke_test_fwd_sink.sh
---------
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -191,6 +191,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
static constexpr auto LOG2E = log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -297,7 +298,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -383,8 +385,24 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -403,7 +421,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -1049,7 +1074,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -1077,7 +1103,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -227,8 +228,24 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
@@ -258,7 +275,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -788,7 +812,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -812,7 +837,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -164,7 +164,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -254,8 +255,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
@@ -285,7 +294,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
@@ -299,7 +315,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
if(i_split > 0)
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
}
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -879,7 +904,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -905,7 +931,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -227,8 +228,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
@@ -260,7 +277,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
@@ -272,6 +296,29 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
if(i_split > 0)
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
}
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -797,7 +844,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -823,7 +871,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -166,7 +166,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -230,9 +231,24 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
set_tile(m, sink_v * scale_s * C_LOG2E);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
@@ -265,7 +281,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -798,7 +821,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -821,7 +845,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
static constexpr auto LOG2E = log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -188,7 +189,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -274,9 +276,24 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
set_tile(m, sink_v * scale_s * LOG2E);
|
||||
else
|
||||
set_tile(m, sink_v * LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
@@ -309,7 +326,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -475,17 +499,10 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -880,7 +897,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -903,7 +921,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -148,7 +148,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -193,8 +194,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
@@ -212,7 +229,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
@@ -649,6 +673,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
float sink_v,
|
||||
void* __restrict__ smem_ptrk0,
|
||||
void* __restrict__ smem_ptrk1,
|
||||
void* __restrict__ smem_ptrv0,
|
||||
@@ -698,8 +723,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
@@ -717,7 +758,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user