p scale support

revert comments
This commit is contained in:
msaffari-amd
2026-05-28 13:26:37 +00:00
committed by 123
parent a6421a8d55
commit a66068e978
3 changed files with 63 additions and 15 deletions

View File

@@ -679,9 +679,19 @@ struct fmha_batch_prefill_args
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
// PER_TOKEN_HEAD optional per-q-head P scale [num_head_q] fp32.
const void* p_scale_ptr = nullptr;
};
// Select KV-cache load mode for batch-prefill.
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile
// so per-page SRD is impossible, AND (b) the total KV-pool byte size
// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it.
// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
// Inputs are taken as plain integers so the helper has no template parameter
// and can be called from each codegen-emitted dispatcher arm with the arm's
// compile-time kN0 / element_bytes substituted as constants.
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
ck_tile::index_t kN0,
@@ -689,7 +699,10 @@ fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
ck_tile::index_t batch_stride_k,
ck_tile::index_t element_bytes)
{
// Promote all operands before multiply to avoid intermediate overflow.
// Promote every operand to long_index_t so overflow is impossible regardless
// of multiplication order. A bare `static_cast<long_index_t>(num_total_pages)
// * batch_stride_k * element_bytes` only works because of left-to-right
// associativity — a future reorder of the operands would silently truncate.
const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
static_cast<ck_tile::long_index_t>(batch_stride_k) *
static_cast<ck_tile::long_index_t>(element_bytes);
@@ -1344,7 +1357,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale);
args.nhead_stride_v_descale,
args.p_scale_ptr);
}
else
{ // create batch mode kernel arguments
@@ -1405,7 +1419,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale);
args.nhead_stride_v_descale,
args.p_scale_ptr);
}
}();

View File

@@ -220,6 +220,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t stride_k_descale_token = 0;
ck_tile::index_t nhead_stride_k_descale = 0;
ck_tile::index_t nhead_stride_v_descale = 0;
// Optional per-q-head P scale [num_head_q] fp32.
const void* p_scale_ptr = nullptr;
};
// Helper template to select QScale Kargs type based on QScaleEnum
@@ -409,7 +411,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t nblock_stride_k_descale_page = 0,
ck_tile::index_t stride_k_descale_token = 0,
ck_tile::index_t nhead_stride_k_descale = 0,
ck_tile::index_t nhead_stride_v_descale = 0)
ck_tile::index_t nhead_stride_v_descale = 0,
const void* p_scale_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -499,6 +502,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.stride_k_descale_token = stride_k_descale_token;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.p_scale_ptr = p_scale_ptr;
}
if constexpr(kHasDropout)
{
@@ -585,7 +589,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t nblock_stride_k_descale_page = 0,
ck_tile::index_t stride_k_descale_token = 0,
ck_tile::index_t nhead_stride_k_descale = 0,
ck_tile::index_t nhead_stride_v_descale = 0)
ck_tile::index_t nhead_stride_v_descale = 0,
const void* p_scale_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -672,6 +677,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.stride_k_descale_token = stride_k_descale_token;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.p_scale_ptr = p_scale_ptr;
}
if constexpr(kHasDropout)
{
@@ -1416,6 +1422,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
const float* p_scale_ptr =
reinterpret_cast<const float*>(kargs.p_scale_ptr);
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
@@ -1445,7 +1454,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.nblock_stride_k_descale_page,
kargs.stride_k_descale_token,
kargs.nhead_stride_k_descale,
kargs.nhead_stride_v_descale);
kargs.nhead_stride_v_descale,
p_scale_ptr);
}
else
{

View File

@@ -449,7 +449,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
index_t nblock_stride_k_descale_page = 0,
index_t stride_k_descale_token = 0,
index_t nhead_stride_k_descale = 0,
index_t nhead_stride_v_descale = 0) const
index_t nhead_stride_v_descale = 0,
// PER_TOKEN_HEAD caller P scale [num_head_q] fp32, optional.
// Folded into the exp2 row-max shift; the rowsum carries the
// same factor so it cancels in O/l with no v_descale fixup.
const float* p_scale_ptr = nullptr) const
{
// KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
// all tokens in a main loop iteration belong to the same page
@@ -462,6 +466,17 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// below precomputes per-(kPageBlockSize)-wide-slice physical page IDs
// and applies them per column.
// Per-q-head P scale (PER_TOKEN_HEAD only): folded into the exp2
// row-max shift below; null pointer reduces to the default 2^shift.
float p_scale_log2 = 0.f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
if(p_scale_ptr != nullptr)
{
p_scale_log2 = log2f(p_scale_ptr[block_indices.qo_head_idx]);
}
}
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
@@ -1463,21 +1478,27 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
// For KV_BLOCKSCALE: precompute (m - shift) once per row
// exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift
// This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply
// exp2(s - (m - shift)) = exp2(s - m) * 2^shift, i.e. P is
// scaled by 2^shift before the fp8 cast: OCP fp8 (e4m3)
// uses 8 -> 256, FNUZ fp8 uses 7 -> 128.
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
row_max -= OCP_FP8_SHIFT; // for else branch
validated_m -= OCP_FP8_SHIFT;
row_max -= OCP_FP8_SHIFT;
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
if constexpr(QScaleEnum ==
BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
validated_m -= p_scale_log2;
row_max -= p_scale_log2;
}
}
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
@@ -1991,7 +2012,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
index_t nblock_stride_k_descale_page,
index_t stride_k_descale_token,
index_t nhead_stride_k_descale,
index_t nhead_stride_v_descale) const
index_t nhead_stride_v_descale,
const float* p_scale_ptr = nullptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -2032,7 +2054,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
nblock_stride_k_descale_page,
stride_k_descale_token,
nhead_stride_k_descale,
nhead_stride_v_descale);
nhead_stride_v_descale,
p_scale_ptr);
}
};