mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[AITERKER-112] Add PER_TOKEN_HEAD FP8 quant scheme to batch_prefill
- New BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD enum value
- Pipeline overload in block_fmha_batch_prefill_pipeline_qr_ks_vs_async that applies the per-token Q/K descale as an outer product after GEMM0 and the per-head V descale at the epilogue
- fmha_batch_prefill_kernel kargs + MakeKargs + pipeline dispatch
- fmha_fwd.hpp host-side traits/args wiring
- quant.hpp trait specialization
- Codegen emits PER_TOKEN_HEAD kernel variants
This commit is contained in:
@@ -81,6 +81,7 @@ QSCALE_MAP = {
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
"per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD",
|
||||
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
|
||||
}
|
||||
|
||||
@@ -89,6 +90,7 @@ QSCALE_CHECK_MAP = {
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
"per_token_head": "quant_scale_enum::per_token_head",
|
||||
"mx": "quant_scale_enum::mx",
|
||||
}
|
||||
|
||||
|
||||
@@ -733,7 +733,7 @@ class KernelComponentFactory:
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor", "kv_blockscale"],
|
||||
["pertensor", "kv_blockscale", "per_token_head"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["t", "f"],
|
||||
@@ -819,9 +819,12 @@ def get_fwd_blobs(
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
|
||||
# kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
|
||||
if (
|
||||
pipeline.F_qscale in ("kv_blockscale", "per_token_head")
|
||||
and page_size < tile.F_bn0
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
|
||||
@@ -671,6 +671,19 @@ struct fmha_batch_prefill_args
|
||||
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
|
||||
// PER_TOKEN_HEAD: Q/K per-token per-head, V per-head (FP8 fine-grained).
|
||||
// q_descale_ptr/k_descale_ptr/v_descale_ptr are reused; layout:
|
||||
// q_descale: [total_q_tokens, nhead_q] fp32
|
||||
// k_descale: [num_total_pages, page_block_size, nhead_k] fp32
|
||||
// (aligned with paged K cache so we can reuse k_physical_pages[])
|
||||
// v_descale: [nhead_k] fp32
|
||||
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
|
||||
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
|
||||
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
|
||||
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
|
||||
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
|
||||
};
|
||||
|
||||
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
|
||||
@@ -1340,7 +1353,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
args.nhead_stride_kv_block_descale,
|
||||
args.stride_q_descale_token,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nblock_stride_k_descale_page,
|
||||
args.stride_k_descale_token,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1395,7 +1414,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
args.nhead_stride_kv_block_descale,
|
||||
args.stride_q_descale_token,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nblock_stride_k_descale_page,
|
||||
args.stride_k_descale_token,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -14,11 +14,12 @@
|
||||
// keep sync with BlockAttentionQuantScaleEnum
|
||||
// Quantization scale scheme selector; enumerator values mirror BlockAttentionQuantScaleEnum.
// NOTE(review): the first five enumerators appear twice below; presumably these are the
// pre- and post-change lines of the diff (re-aligned when per_token_head was added) — confirm.
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
mx = 4, // Microscaling (MX)
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
mx = 4, // Microscaling (MX)
|
||||
per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -37,6 +38,8 @@ struct quant_scale_info
|
||||
os << "kvbs";
|
||||
else if(type == quant_scale_enum::mx)
|
||||
os << "mx";
|
||||
else if(type == quant_scale_enum::per_token_head)
|
||||
os << "pth";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -62,6 +65,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::mx;
|
||||
}
|
||||
else if(str == "pth" || str == "5")
|
||||
{
|
||||
info.type = quant_scale_enum::per_token_head;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
@@ -10,11 +10,12 @@ namespace ck_tile {
|
||||
// This class is used for codegen pattern matching
|
||||
// Quantization scale scheme used by the attention block pipelines; PER_TOKEN_HEAD (5) is the
// newly added value.
// NOTE(review): the first five enumerators appear twice below; presumably these are the
// pre- and post-change lines of the diff (re-aligned when PER_TOKEN_HEAD was added) — confirm.
enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE = 2,
|
||||
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
|
||||
MX = 4, // Microscaling
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE = 2,
|
||||
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
|
||||
MX = 4, // Microscaling
|
||||
PER_TOKEN_HEAD = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
@@ -45,5 +46,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::MX>
|
||||
{
|
||||
static constexpr const char* name = "mx";
|
||||
};
|
||||
// Name lookup for PER_TOKEN_HEAD: exposes the string "per_token_head" as `name`.
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD>
|
||||
{
|
||||
static constexpr const char* name = "per_token_head";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -205,6 +205,23 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
// PER_TOKEN_HEAD: Q per-token-per-head, K per-token-per-head (paged-aligned), V per-head
|
||||
// q_descale: [total_q, nhead_q]
|
||||
// k_descale: [num_total_pages, page_block_size, nhead_k]
|
||||
// v_descale: [nhead_k]
|
||||
// Kernel arguments for the PER_TOKEN_HEAD quant scheme: three descale buffers plus the
// strides used to index them. All members default to null / zero.
struct FmhaFwdPerTokenHeadKargs
|
||||
{
|
||||
// Descale buffers; the kernel reinterprets each of them as const float* before use.
const void* q_descale_ptr = nullptr;
|
||||
const void* k_descale_ptr = nullptr;
|
||||
const void* v_descale_ptr = nullptr;
|
||||
ck_tile::index_t stride_q_descale_token = 0; // Q descale: stride between tokens (rows)
|
||||
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: stride between heads
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: stride between pages
|
||||
ck_tile::index_t stride_k_descale_token = 0; // K descale: stride between tokens within a page
|
||||
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: stride between heads
|
||||
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: stride between heads (per-head only)
|
||||
};
|
||||
|
||||
// Helper template to select QScale Kargs type based on QScaleEnum
|
||||
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
|
||||
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
|
||||
@@ -225,6 +242,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
using type = FmhaFwdKVBlockScaleKargs;
|
||||
};
|
||||
|
||||
// PER_TOKEN_HEAD specialization: selects FmhaFwdPerTokenHeadKargs as the QScale kargs type,
// regardless of EmptyType.
template <typename EmptyType>
|
||||
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD, EmptyType>
|
||||
{
|
||||
using type = FmhaFwdPerTokenHeadKargs;
|
||||
};
|
||||
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
template <typename T>
|
||||
@@ -379,7 +402,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0,
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0)
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0,
|
||||
// PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD)
|
||||
ck_tile::index_t stride_q_descale_token = 0,
|
||||
ck_tile::index_t nhead_stride_q_descale = 0,
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0,
|
||||
ck_tile::index_t stride_k_descale_token = 0,
|
||||
ck_tile::index_t nhead_stride_k_descale = 0,
|
||||
ck_tile::index_t nhead_stride_v_descale = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -458,6 +488,18 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
|
||||
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
kargs.stride_q_descale_token = stride_q_descale_token;
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page;
|
||||
kargs.stride_k_descale_token = stride_k_descale_token;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -536,7 +578,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0,
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0)
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0,
|
||||
// PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD)
|
||||
ck_tile::index_t stride_q_descale_token = 0,
|
||||
ck_tile::index_t nhead_stride_q_descale = 0,
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0,
|
||||
ck_tile::index_t stride_k_descale_token = 0,
|
||||
ck_tile::index_t nhead_stride_k_descale = 0,
|
||||
ck_tile::index_t nhead_stride_v_descale = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -612,6 +661,18 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
|
||||
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
kargs.stride_q_descale_token = stride_q_descale_token;
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page;
|
||||
kargs.stride_k_descale_token = stride_k_descale_token;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -1222,6 +1283,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
return kargs.scale_s * q_descale;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// Q/K descales are per-token-per-head, applied as outer product in pipeline.
|
||||
// Here we only forward the softmax scale (1/sqrt(d)).
|
||||
return kargs.scale_s;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.scale_s;
|
||||
@@ -1339,6 +1406,47 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.nblock_stride_kv_block_descale,
|
||||
kargs.nhead_stride_kv_block_descale);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD: Q/K descales are per-token-per-head, V is per-head.
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
const float* q_descale_ptr = reinterpret_cast<const float*>(kargs.q_descale_ptr);
|
||||
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
|
||||
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout,
|
||||
sink_value,
|
||||
max_page_table_idx,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.stride_q_descale_token,
|
||||
kargs.nhead_stride_q_descale,
|
||||
kargs.nblock_stride_k_descale_page,
|
||||
kargs.stride_k_descale_token,
|
||||
kargs.nhead_stride_k_descale,
|
||||
kargs.nhead_stride_v_descale);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
|
||||
@@ -436,7 +436,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
index_t nblock_stride_kv_block_descale = 0,
|
||||
index_t nhead_stride_kv_block_descale = 0) const
|
||||
index_t nhead_stride_kv_block_descale = 0,
|
||||
// PER_TOKEN_HEAD parameters (only used when QScaleEnum == PER_TOKEN_HEAD)
|
||||
// Reuses k_descale_ptr / v_descale_ptr above; q_descale provided here.
|
||||
// Layouts:
|
||||
// q_descale_per_token_ptr: [total_q, nhead_q]
|
||||
// k_descale_ptr (when PER_TOKEN_HEAD): [num_total_pages, page_block_size, nhead_k]
|
||||
// v_descale_ptr (when PER_TOKEN_HEAD): [nhead_k]
|
||||
const float* q_descale_per_token_ptr = nullptr,
|
||||
index_t stride_q_descale_token = 0,
|
||||
index_t nhead_stride_q_descale = 0,
|
||||
index_t nblock_stride_k_descale_page = 0,
|
||||
index_t stride_k_descale_token = 0,
|
||||
index_t nhead_stride_k_descale = 0,
|
||||
index_t nhead_stride_v_descale = 0) const
|
||||
{
|
||||
// KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
|
||||
// all tokens in a main loop iteration belong to the same page
|
||||
@@ -444,6 +457,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0");
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
static_assert(kPageBlockSize >= kN0,
|
||||
"PER_TOKEN_HEAD requires kPageBlockSize >= kN0");
|
||||
}
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -1027,6 +1045,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
k_descale = k_descale_ptr[scale_offset];
|
||||
v_descale = v_descale_ptr[scale_offset];
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// V scale is per-head only; load the scalar at v_descale_ptr[kv_head_idx * nhead_stride_v_descale].
|
||||
// K scale is per-token-per-head and is applied as a vector after GEMM0
|
||||
// (see PER_TOKEN_HEAD branch below).
|
||||
v_descale = v_descale_ptr[block_indices.kv_head_idx * nhead_stride_v_descale];
|
||||
}
|
||||
|
||||
// Prefetch V physical pages early - overlaps with GEMM0 computation
|
||||
save_and_prefetch_v_pages(number<kK1>{});
|
||||
@@ -1087,6 +1112,37 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
|
||||
}
|
||||
// PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale.
|
||||
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page, k_slot+j, kv_head]
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
const index_t k_page = k_physical_pages[number<0>{}];
|
||||
const index_t k_slot_base = k_origin.at(number<0>{}) % kPageBlockSize;
|
||||
const index_t qo_head = block_indices.qo_head_idx;
|
||||
const index_t kv_head = block_indices.kv_head_idx;
|
||||
const index_t q_row_base = q_origin.at(number<0>{});
|
||||
|
||||
const index_t k_page_base = k_page * nblock_stride_k_descale_page +
|
||||
kv_head * nhead_stride_k_descale;
|
||||
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
const index_t i = tile_idx.at(number<0>{});
|
||||
const index_t j = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const float qd = q_descale_per_token_ptr[
|
||||
(q_row_base + i) * stride_q_descale_token +
|
||||
qo_head * nhead_stride_q_descale];
|
||||
const float kd = k_descale_ptr[
|
||||
k_page_base + (k_slot_base + j) * stride_k_descale_token];
|
||||
s_acc(i_j_idx) *= qd * kd;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
@@ -1309,7 +1365,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
@@ -1427,7 +1484,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// STAGE 3, KV gemm
|
||||
// KV_BLOCKSCALE / PER_TOKEN_HEAD: accumulate P*V into a temporary tile before applying v_descale
|
||||
auto o_acc_unscaled = decltype(o_acc){};
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
clear_tile(o_acc_unscaled);
|
||||
}
|
||||
@@ -1435,7 +1493,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE / PER_TOKEN_HEAD (both need v_descale), o_acc
|
||||
// otherwise
|
||||
auto& gemm1_acc = [&]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
return o_acc_unscaled;
|
||||
else
|
||||
return o_acc;
|
||||
@@ -1586,7 +1645,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// 1. P was scaled by 2^shift through exp2 shift trick
|
||||
// 2. rowsum l was also scaled by 2^shift
|
||||
// 3. Final O = sum(P*V) / l, so the 2^shift cancels out
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; },
|
||||
@@ -1787,6 +1847,90 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
nblock_stride_kv_block_descale,
|
||||
nhead_stride_kv_block_descale);
|
||||
}
|
||||
|
||||
// Overload for PER_TOKEN_HEAD: Q/K per-token-per-head, V per-head
//
// Thin forwarding wrapper around the full operator() overload. It passes identity{} for
// every functor slot, places the K and V descale pointers in the k_descale_ptr /
// v_descale_ptr slots, passes 0 for the two KV_BLOCKSCALE strides, and appends the Q descale
// pointer followed by the six PER_TOKEN_HEAD strides.
//
// Descale layouts (as documented on the full overload's parameters):
//   q_descale_per_token_ptr: [total_q, nhead_q]
//   k_descale_per_token_ptr: [num_total_pages, page_block_size, nhead_k]
//   v_descale_per_head_ptr:  [nhead_k]
// Returns whatever the forwarded call returns.
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout,
|
||||
float sink_v,
|
||||
const index_t max_page_table_idx,
|
||||
const float* q_descale_per_token_ptr,
|
||||
const float* k_descale_per_token_ptr,
|
||||
const float* v_descale_per_head_ptr,
|
||||
index_t stride_q_descale_token,
|
||||
index_t nhead_stride_q_descale,
|
||||
index_t nblock_stride_k_descale_page,
|
||||
index_t stride_k_descale_token,
|
||||
index_t nhead_stride_k_descale,
|
||||
index_t nhead_stride_v_descale) const
|
||||
{
|
||||
// NOTE(review): argument order must match the full overload's parameter list, which is
// only partly visible in this view — verify positions when changing either side.
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout,
|
||||
sink_v,
|
||||
max_page_table_idx,
|
||||
k_descale_per_token_ptr, // reused: k_descale_ptr slot
|
||||
v_descale_per_head_ptr, // reused: v_descale_ptr slot
|
||||
// The two KV_BLOCKSCALE strides are passed as 0 by this overload.
/*nblock_stride_kv_block_descale*/ 0,
|
||||
/*nhead_stride_kv_block_descale*/ 0,
|
||||
q_descale_per_token_ptr,
|
||||
stride_q_descale_token,
|
||||
nhead_stride_q_descale,
|
||||
nblock_stride_k_descale_page,
|
||||
stride_k_descale_token,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user