mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Support logits_soft_cap in batch_decode()
This commit is contained in:
@@ -915,6 +915,7 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
kargs,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
@@ -932,6 +933,7 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
kargs,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
|
||||
@@ -136,7 +136,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename LogitsSoftCapParams>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -156,6 +157,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const LogitsSoftCapParams& logits_soft_cap_params,
|
||||
void* smem_ptr,
|
||||
const int32_t* kv_page_indices,
|
||||
const index_t stride_k,
|
||||
@@ -472,6 +474,19 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#else
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
float scale_lo =
|
||||
scale_s * 0.6931472f * logits_soft_cap_params.logits_soft_cap_rcp;
|
||||
float logits_cap =
|
||||
log2e_v<SaccDataType> * logits_soft_cap_params.logits_soft_cap;
|
||||
tile_elementwise_inout(
|
||||
[&scale_lo, &logits_cap](auto& x) {
|
||||
x = logits_cap * tanh_fast<SaccDataType>(x * scale_lo);
|
||||
},
|
||||
s_acc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
@@ -579,7 +594,14 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
@@ -608,8 +630,15 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
@@ -759,7 +788,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename LogitsSoftCapParams>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -771,6 +801,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const LogitsSoftCapParams& logits_soft_cap_params,
|
||||
void* smem_ptr,
|
||||
const int32_t* kv_page_indices,
|
||||
const index_t stride_k,
|
||||
@@ -794,6 +825,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
logits_soft_cap_params,
|
||||
smem_ptr,
|
||||
kv_page_indices,
|
||||
stride_k,
|
||||
|
||||
@@ -112,7 +112,7 @@ struct BlockFmhaBatchPrefillWithPagedKVCachePipelineQRKSVS
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user