[AITERKER-112] Add PER_TOKEN_HEAD FP8 quant scheme to batch_prefill

- New BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD enum value
- Pipeline overload in block_fmha_batch_prefill_pipeline_qr_ks_vs_async
  applying per-token Q/K descale via GEMM0-post outer product and
  per-head V descale at epilogue
- fmha_batch_prefill_kernel kargs + MakeKargs + pipeline dispatch
- fmha_fwd.hpp host-side traits/args wiring
- quant.hpp trait specialization
- Codegen emits PER_TOKEN_HEAD kernel variants
This commit is contained in:
msaffari-amd
2026-05-19 15:41:32 +00:00
parent 83566edb0f
commit 403d99124d
7 changed files with 317 additions and 22 deletions

View File

@@ -81,6 +81,7 @@ QSCALE_MAP = {
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
"per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD",
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
}
@@ -89,6 +90,7 @@ QSCALE_CHECK_MAP = {
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
"per_token_head": "quant_scale_enum::per_token_head",
"mx": "quant_scale_enum::mx",
}

View File

@@ -733,7 +733,7 @@ class KernelComponentFactory:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor", "kv_blockscale"],
["pertensor", "kv_blockscale", "per_token_head"],
get_mask_map(mask_impl).keys(),
["no"],
["t", "f"],
@@ -819,9 +819,12 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
if (
pipeline.F_qscale in ("kv_blockscale", "per_token_head")
and page_size < tile.F_bn0
):
continue
k = FmhaFwdKernel(
F_idx=0,

View File

@@ -671,6 +671,19 @@ struct fmha_batch_prefill_args
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
// PER_TOKEN_HEAD: Q/K per-token per-head, V per-head (FP8 fine-grained).
// q_descale_ptr/k_descale_ptr/v_descale_ptr are reused; layout:
// q_descale: [total_q_tokens, nhead_q] fp32
// k_descale: [num_total_pages, page_block_size, nhead_k] fp32
// (aligned with paged K cache so we can reuse k_physical_pages[])
// v_descale: [nhead_k] fp32
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
};
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
@@ -1340,7 +1353,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
args.nhead_stride_kv_block_descale,
args.stride_q_descale_token,
args.nhead_stride_q_descale,
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale);
}
else
{ // create batch mode kernel arguments
@@ -1395,7 +1414,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
args.nhead_stride_kv_block_descale,
args.stride_q_descale_token,
args.nhead_stride_q_descale,
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale);
}
}();

View File

@@ -14,11 +14,12 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
};
struct quant_scale_info
@@ -37,6 +38,8 @@ struct quant_scale_info
os << "kvbs";
else if(type == quant_scale_enum::mx)
os << "mx";
else if(type == quant_scale_enum::per_token_head)
os << "pth";
}
static quant_scale_info decode(std::string str)
@@ -62,6 +65,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::mx;
}
else if(str == "pth" || str == "5")
{
info.type = quant_scale_enum::per_token_head;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);

View File

@@ -10,11 +10,12 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
MX = 4, // Microscaling
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
MX = 4, // Microscaling
PER_TOKEN_HEAD = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
};
template <BlockAttentionQuantScaleEnum>
@@ -45,5 +46,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::MX>
{
static constexpr const char* name = "mx";
};
template <>
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD>
{
static constexpr const char* name = "per_token_head";
};
} // namespace ck_tile

View File

@@ -205,6 +205,23 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
// PER_TOKEN_HEAD: Q per-token-per-head, K per-token-per-head (paged-aligned), V per-head
// q_descale: [total_q, nhead_q]
// k_descale: [num_total_pages, page_block_size, nhead_k]
// v_descale: [nhead_k]
struct FmhaFwdPerTokenHeadKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
ck_tile::index_t stride_q_descale_token = 0;
ck_tile::index_t nhead_stride_q_descale = 0;
ck_tile::index_t nblock_stride_k_descale_page = 0;
ck_tile::index_t stride_k_descale_token = 0;
ck_tile::index_t nhead_stride_k_descale = 0;
ck_tile::index_t nhead_stride_v_descale = 0;
};
// Helper template to select QScale Kargs type based on QScaleEnum
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
@@ -225,6 +242,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
using type = FmhaFwdKVBlockScaleKargs;
};
template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD, EmptyType>
{
using type = FmhaFwdPerTokenHeadKargs;
};
struct FmhaFwdDropoutSeedOffset
{
template <typename T>
@@ -379,7 +402,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
drop_seed_offset,
const void* sink_ptr = nullptr,
ck_tile::index_t nblock_stride_kv_block_descale = 0,
ck_tile::index_t nhead_stride_kv_block_descale = 0)
ck_tile::index_t nhead_stride_kv_block_descale = 0,
// PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD)
ck_tile::index_t stride_q_descale_token = 0,
ck_tile::index_t nhead_stride_q_descale = 0,
ck_tile::index_t nblock_stride_k_descale_page = 0,
ck_tile::index_t stride_k_descale_token = 0,
ck_tile::index_t nhead_stride_k_descale = 0,
ck_tile::index_t nhead_stride_v_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -458,6 +488,18 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.stride_q_descale_token = stride_q_descale_token;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page;
kargs.stride_k_descale_token = stride_k_descale_token;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -536,7 +578,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
drop_seed_offset,
const void* sink_ptr = nullptr,
ck_tile::index_t nblock_stride_kv_block_descale = 0,
ck_tile::index_t nhead_stride_kv_block_descale = 0)
ck_tile::index_t nhead_stride_kv_block_descale = 0,
// PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD)
ck_tile::index_t stride_q_descale_token = 0,
ck_tile::index_t nhead_stride_q_descale = 0,
ck_tile::index_t nblock_stride_k_descale_page = 0,
ck_tile::index_t stride_k_descale_token = 0,
ck_tile::index_t nhead_stride_k_descale = 0,
ck_tile::index_t nhead_stride_v_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -612,6 +661,18 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.stride_q_descale_token = stride_q_descale_token;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page;
kargs.stride_k_descale_token = stride_k_descale_token;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1222,6 +1283,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
return kargs.scale_s * q_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// Q/K descales are per-token-per-head, applied as outer product in pipeline.
// Here we only forward the softmax scale (1/sqrt(d)).
return kargs.scale_s;
}
else
{
return kargs.scale_s;
@@ -1339,6 +1406,47 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.nblock_stride_kv_block_descale,
kargs.nhead_stride_kv_block_descale);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD: Q/K descales are per-token-per-head, V is per-head.
assert(kargs.q_descale_ptr != nullptr);
assert(kargs.k_descale_ptr != nullptr);
assert(kargs.v_descale_ptr != nullptr);
const float* q_descale_ptr = reinterpret_cast<const float*>(kargs.q_descale_ptr);
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout,
sink_value,
max_page_table_idx,
q_descale_ptr,
k_descale_ptr,
v_descale_ptr,
kargs.stride_q_descale_token,
kargs.nhead_stride_q_descale,
kargs.nblock_stride_k_descale_page,
kargs.stride_k_descale_token,
kargs.nhead_stride_k_descale,
kargs.nhead_stride_v_descale);
}
else
{
return FmhaPipeline{}(q_dram_window,

View File

@@ -436,7 +436,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
index_t nblock_stride_kv_block_descale = 0,
index_t nhead_stride_kv_block_descale = 0) const
index_t nhead_stride_kv_block_descale = 0,
// PER_TOKEN_HEAD parameters (only used when QScaleEnum == PER_TOKEN_HEAD)
// Reuses k_descale_ptr / v_descale_ptr above; q_descale provided here.
// Layouts:
// q_descale_per_token_ptr: [total_q, nhead_q]
// k_descale_ptr (when PER_TOKEN_HEAD): [num_total_pages, page_block_size, nhead_k]
// v_descale_ptr (when PER_TOKEN_HEAD): [nhead_k]
const float* q_descale_per_token_ptr = nullptr,
index_t stride_q_descale_token = 0,
index_t nhead_stride_q_descale = 0,
index_t nblock_stride_k_descale_page = 0,
index_t stride_k_descale_token = 0,
index_t nhead_stride_k_descale = 0,
index_t nhead_stride_v_descale = 0) const
{
// KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
// all tokens in a main loop iteration belong to the same page
@@ -444,6 +457,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0");
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
static_assert(kPageBlockSize >= kN0,
"PER_TOKEN_HEAD requires kPageBlockSize >= kN0");
}
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -1027,6 +1045,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
k_descale = k_descale_ptr[scale_offset];
v_descale = v_descale_ptr[scale_offset];
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// V scale is per-head only; load scalar from v_descale_ptr[kv_head_idx].
// K scale is per-token-per-head and is applied as a vector after GEMM0
// (see PER_TOKEN_HEAD branch below).
v_descale = v_descale_ptr[block_indices.kv_head_idx * nhead_stride_v_descale];
}
// Prefetch V physical pages early - overlaps with GEMM0 computation
save_and_prefetch_v_pages(number<kK1>{});
@@ -1087,6 +1112,37 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
}
// PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale.
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page, k_slot+j, kv_head]
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
const auto k_origin = k_dram_block_window.get_window_origin();
const index_t k_page = k_physical_pages[number<0>{}];
const index_t k_slot_base = k_origin.at(number<0>{}) % kPageBlockSize;
const index_t qo_head = block_indices.qo_head_idx;
const index_t kv_head = block_indices.kv_head_idx;
const index_t q_row_base = q_origin.at(number<0>{});
const index_t k_page_base = k_page * nblock_stride_k_descale_page +
kv_head * nhead_stride_k_descale;
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const index_t i = tile_idx.at(number<0>{});
const index_t j = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const float qd = q_descale_per_token_ptr[
(q_row_base + i) * stride_q_descale_token +
qo_head * nhead_stride_q_descale];
const float kd = k_descale_ptr[
k_page_base + (k_slot_base + j) * stride_k_descale_token];
s_acc(i_j_idx) *= qd * kd;
});
});
}
const auto p = [&]() {
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
@@ -1309,7 +1365,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
@@ -1427,7 +1484,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// STAGE 3, KV gemm
// KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale
auto o_acc_unscaled = decltype(o_acc){};
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
clear_tile(o_acc_unscaled);
}
@@ -1435,7 +1493,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE (needs v_descale), o_acc
// otherwise
auto& gemm1_acc = [&]() -> auto& {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
return o_acc_unscaled;
else
return o_acc;
@@ -1586,7 +1645,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// 1. P was scaled by 2^shift through exp2 shift trick
// 2. rowsum l was also scaled by 2^shift
// 3. Final O = sum(P*V) / l, so the 2^shift cancels out
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
tile_elementwise_inout(
[&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; },
@@ -1787,6 +1847,90 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
nblock_stride_kv_block_descale,
nhead_stride_kv_block_descale);
}
// Overload for PER_TOKEN_HEAD: Q/K per-token-per-head, V per-head
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout,
float sink_v,
const index_t max_page_table_idx,
const float* q_descale_per_token_ptr,
const float* k_descale_per_token_ptr,
const float* v_descale_per_head_ptr,
index_t stride_q_descale_token,
index_t nhead_stride_q_descale,
index_t nblock_stride_k_descale_page,
index_t stride_k_descale_token,
index_t nhead_stride_k_descale,
index_t nhead_stride_v_descale) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k,
stride_v,
page_stride_k,
page_stride_v,
dropout,
sink_v,
max_page_table_idx,
k_descale_per_token_ptr, // reused: k_descale_ptr slot
v_descale_per_head_ptr, // reused: v_descale_ptr slot
/*nblock_stride_kv_block_descale*/ 0,
/*nhead_stride_kv_block_descale*/ 0,
q_descale_per_token_ptr,
stride_q_descale_token,
nhead_stride_q_descale,
nblock_stride_k_descale_page,
stride_k_descale_token,
nhead_stride_k_descale,
nhead_stride_v_descale);
}
};
} // namespace ck_tile