Support logits_soft_cap in batch_decode()

This commit is contained in:
Po Yen Chen
2025-04-21 23:33:52 +00:00
parent fd84cf840a
commit 9512f78616
3 changed files with 40 additions and 6 deletions

View File

@@ -915,6 +915,7 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
mask,
position_encoding,
kargs.scale_s,
kargs,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
@@ -932,6 +933,7 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
mask,
position_encoding,
kargs.scale_s,
kargs,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,

View File

@@ -136,7 +136,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename LogitsSoftCapParams>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -156,6 +157,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const LogitsSoftCapParams& logits_soft_cap_params,
void* smem_ptr,
const int32_t* kv_page_indices,
const index_t stride_k,
@@ -472,6 +474,19 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#else
if constexpr(kHasLogitsSoftCap)
{
float scale_lo =
scale_s * 0.6931472f * logits_soft_cap_params.logits_soft_cap_rcp;
float logits_cap =
log2e_v<SaccDataType> * logits_soft_cap_params.logits_soft_cap;
tile_elementwise_inout(
[&scale_lo, &logits_cap](auto& x) {
x = logits_cap * tanh_fast<SaccDataType>(x * scale_lo);
},
s_acc);
}
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -579,7 +594,14 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
@@ -608,8 +630,15 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
@@ -759,7 +788,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename LogitsSoftCapParams>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -771,6 +801,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const LogitsSoftCapParams& logits_soft_cap_params,
void* smem_ptr,
const int32_t* kv_page_indices,
const index_t stride_k,
@@ -794,6 +825,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
mask,
position_encoding,
scale_s,
logits_soft_cap_params,
smem_ptr,
kv_page_indices,
stride_k,

View File

@@ -112,7 +112,7 @@ struct BlockFmhaBatchPrefillWithPagedKVCachePipelineQRKSVS
else
{
return 1;
};
}
}
}();