mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
p scale support
revert comments
This commit is contained in:
@@ -679,9 +679,19 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
|
||||
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
|
||||
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
|
||||
|
||||
// PER_TOKEN_HEAD optional per-q-head P scale [num_head_q] fp32.
|
||||
const void* p_scale_ptr = nullptr;
|
||||
};
|
||||
|
||||
// Select KV-cache load mode for batch-prefill.
|
||||
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
|
||||
// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile
|
||||
// so per-page SRD is impossible, AND (b) the total KV-pool byte size
|
||||
// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it.
|
||||
// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
|
||||
// Inputs are taken as plain integers so the helper has no template parameter
|
||||
// and can be called from each codegen-emitted dispatcher arm with the arm's
|
||||
// compile-time kN0 / element_bytes substituted as constants.
|
||||
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
|
||||
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t kN0,
|
||||
@@ -689,7 +699,10 @@ fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t element_bytes)
|
||||
{
|
||||
// Promote all operands before multiply to avoid intermediate overflow.
|
||||
// Promote every operand to long_index_t so overflow is impossible regardless
|
||||
// of multiplication order. A bare `static_cast<long_index_t>(num_total_pages)
|
||||
// * batch_stride_k * element_bytes` only works because of left-to-right
|
||||
// associativity — a future reorder of the operands would silently truncate.
|
||||
const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
|
||||
static_cast<ck_tile::long_index_t>(batch_stride_k) *
|
||||
static_cast<ck_tile::long_index_t>(element_bytes);
|
||||
@@ -1344,7 +1357,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nblock_stride_k_descale_page,
|
||||
args.stride_k_descale_token,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale);
|
||||
args.nhead_stride_v_descale,
|
||||
args.p_scale_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1405,7 +1419,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nblock_stride_k_descale_page,
|
||||
args.stride_k_descale_token,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale);
|
||||
args.nhead_stride_v_descale,
|
||||
args.p_scale_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -220,6 +220,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t stride_k_descale_token = 0;
|
||||
ck_tile::index_t nhead_stride_k_descale = 0;
|
||||
ck_tile::index_t nhead_stride_v_descale = 0;
|
||||
// Optional per-q-head P scale [num_head_q] fp32.
|
||||
const void* p_scale_ptr = nullptr;
|
||||
};
|
||||
|
||||
// Helper template to select QScale Kargs type based on QScaleEnum
|
||||
@@ -409,7 +411,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0,
|
||||
ck_tile::index_t stride_k_descale_token = 0,
|
||||
ck_tile::index_t nhead_stride_k_descale = 0,
|
||||
ck_tile::index_t nhead_stride_v_descale = 0)
|
||||
ck_tile::index_t nhead_stride_v_descale = 0,
|
||||
const void* p_scale_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -499,6 +502,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.stride_k_descale_token = stride_k_descale_token;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
kargs.p_scale_ptr = p_scale_ptr;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -585,7 +589,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0,
|
||||
ck_tile::index_t stride_k_descale_token = 0,
|
||||
ck_tile::index_t nhead_stride_k_descale = 0,
|
||||
ck_tile::index_t nhead_stride_v_descale = 0)
|
||||
ck_tile::index_t nhead_stride_v_descale = 0,
|
||||
const void* p_scale_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -672,6 +677,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.stride_k_descale_token = stride_k_descale_token;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
kargs.p_scale_ptr = p_scale_ptr;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -1416,6 +1422,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
|
||||
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
|
||||
|
||||
const float* p_scale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.p_scale_ptr);
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
@@ -1445,7 +1454,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.nblock_stride_k_descale_page,
|
||||
kargs.stride_k_descale_token,
|
||||
kargs.nhead_stride_k_descale,
|
||||
kargs.nhead_stride_v_descale);
|
||||
kargs.nhead_stride_v_descale,
|
||||
p_scale_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -449,7 +449,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
index_t nblock_stride_k_descale_page = 0,
|
||||
index_t stride_k_descale_token = 0,
|
||||
index_t nhead_stride_k_descale = 0,
|
||||
index_t nhead_stride_v_descale = 0) const
|
||||
index_t nhead_stride_v_descale = 0,
|
||||
// PER_TOKEN_HEAD caller P scale [num_head_q] fp32, optional.
|
||||
// Folded into the exp2 row-max shift; the rowsum carries the
|
||||
// same factor so it cancels in O/l with no v_descale fixup.
|
||||
const float* p_scale_ptr = nullptr) const
|
||||
{
|
||||
// KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
|
||||
// all tokens in a main loop iteration belong to the same page
|
||||
@@ -462,6 +466,17 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// below precomputes per-(kPageBlockSize)-wide-slice physical page IDs
|
||||
// and applies them per column.
|
||||
|
||||
// Per-q-head P scale (PER_TOKEN_HEAD only): folded into the exp2
|
||||
// row-max shift below; null pointer reduces to the default 2^shift.
|
||||
float p_scale_log2 = 0.f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
if(p_scale_ptr != nullptr)
|
||||
{
|
||||
p_scale_log2 = log2f(p_scale_ptr[block_indices.qo_head_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
@@ -1463,21 +1478,27 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
// For KV_BLOCKSCALE: precompute (m - shift) once per row
|
||||
// exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift
|
||||
// This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply
|
||||
// exp2(s - (m - shift)) = exp2(s - m) * 2^shift, i.e. P is
|
||||
// scaled by 2^shift before the fp8 cast: OCP fp8 (e4m3)
|
||||
// uses 8 -> 256, FNUZ fp8 uses 7 -> 128.
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
validated_m -= OCP_FP8_SHIFT;
|
||||
row_max -= OCP_FP8_SHIFT;
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
if constexpr(QScaleEnum ==
|
||||
BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
validated_m -= p_scale_log2;
|
||||
row_max -= p_scale_log2;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
@@ -1991,7 +2012,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
index_t nblock_stride_k_descale_page,
|
||||
index_t stride_k_descale_token,
|
||||
index_t nhead_stride_k_descale,
|
||||
index_t nhead_stride_v_descale) const
|
||||
index_t nhead_stride_v_descale,
|
||||
const float* p_scale_ptr = nullptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -2032,7 +2054,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
nblock_stride_k_descale_page,
|
||||
stride_k_descale_token,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale);
|
||||
nhead_stride_v_descale,
|
||||
p_scale_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user