[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)

[CK] Add FP8 KV_BLOCKSCALE support for batch prefill

Implement per-page K/V quantization for paged attention:
  - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum
  - Use the exp2 shift trick to eliminate explicit P scaling overhead
    (see the sketch below)
  - Prefetch physical page offsets for the KV cache so the lookups
    overlap with computation
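
For context, the exp2 shift trick folds a constant into the softmax exponent: `exp2(s - (m - shift)) == exp2(s - m) * 2^shift`, so P is scaled toward the FP8 range without a separate multiply. A minimal standalone sketch of the identity, assuming `shift = 8` as used for OCP fp8 in the pipeline diff below (illustrative values, not kernel code):

```cpp
#include <cassert>
#include <cmath>

int main()
{
    const float shift = 8.0f;        // matches OCP_FP8_SHIFT in the pipeline
    const float s = -3.2f, m = 1.5f; // one logit and its row max

    const float p        = std::exp2f(s - m);         // in (0, 1]
    const float p_scaled = std::exp2f(s - m + shift); // in (0, 256]

    // Folding the shift into the exponent scales P by 2^8 = 256, which stays
    // within the fp8 e4m3 max of 448, so no explicit "P *= scale_p" is needed.
    assert(std::fabs(p_scaled - p * 256.0f) <= 1e-4f * p_scaled);
    return 0;
}
```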

## Proposed changes

Please describe the motivation behind the pull request, whether it
enables a new feature or fixes a bug. If there are associated pull
requests or issues, please link them to the pull request.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to the REGRESSION_TESTS list defined at the top
of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers to
understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered.
Jeff Huang
2026-02-04 23:26:20 +00:00
committed by assistant-librarian[bot]
parent 62fbda4d1e
commit 7b18f5fed2
8 changed files with 559 additions and 105 deletions

Jenkinsfile

@@ -1784,7 +1784,7 @@ pipeline {
agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCK_CXX_STANDARD="17" """
-                        execute_args = build_client_examples_and_codegen_tests("gfx90a")
+                        execute_args = build_client_examples("gfx90a")
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')


@@ -78,12 +78,14 @@ QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}
QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
}
BIAS_MAP = {


@@ -677,7 +677,7 @@ class KernelComponentFactory:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
["pertensor", "kv_blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
@@ -740,6 +740,10 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,


@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
// KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
// k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
// k_descale_ptr: [num_block, num_kv_head] - points to k block descale
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
template <typename FmhaKernel>
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
-                               args.sink_ptr);
+                               args.sink_ptr,
+                               args.nblock_stride_kv_block_descale,
+                               args.nhead_stride_kv_block_descale);
}
else
{ // create batch mode kernel arguments
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
-                               args.sink_ptr);
+                               args.sink_ptr,
+                               args.nblock_stride_kv_block_descale,
+                               args.nhead_stride_kv_block_descale);
}
}();
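
As a usage note, here is a hedged, illustrative fragment (not from this PR) showing how a caller might fill the new KV_BLOCKSCALE fields of `fmha_batch_prefill_args`; the buffer names and `num_kv_head` value are hypothetical, and the descale tensors are assumed row-major `[num_block, num_kv_head]`:

```cpp
// Assumes fmha_batch_prefill_args from this PR is in scope and the descale
// buffers are caller-owned device allocations of num_block * num_kv_head
// floats, laid out row-major as [num_block, num_kv_head].
const int    num_kv_head = 8;       // assumed
const float* q_descale   = nullptr; // per-tensor Q descale (placeholder)
const float* k_descales  = nullptr; // [num_block, num_kv_head] (placeholder)
const float* v_descales  = nullptr; // [num_block, num_kv_head] (placeholder)

fmha_batch_prefill_args args{};
// ... Q/K/V pointers, strides, and sequence lengths elided ...
args.q_descale_ptr = q_descale;
args.k_descale_ptr = k_descales;
args.v_descale_ptr = v_descales;
args.nblock_stride_kv_block_descale = num_kv_head; // step between pages
args.nhead_stride_kv_block_descale  = 1;           // step between heads
```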


@@ -14,9 +14,10 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
-    no_scale  = 0,
-    pertensor = 1,
-    blockscale,
+    no_scale      = 0,
+    pertensor     = 1,
+    blockscale    = 2,
+    kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
};
struct quant_scale_info
@@ -31,6 +32,8 @@ struct quant_scale_info
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
}
static quant_scale_info decode(std::string str)
@@ -48,6 +51,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "kvbs" || str == "3")
{
info.type = quant_scale_enum::kv_blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
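
A quick usage sketch of the extended decode, per the branches above (illustrative only; both spellings select the new mode, anything unrecognized throws):

```cpp
quant_scale_info by_name  = quant_scale_info::decode("kvbs"); // by name
quant_scale_info by_value = quant_scale_info::decode("3");    // by value
// by_name.type == by_value.type == quant_scale_enum::kv_blockscale
```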


@@ -10,9 +10,10 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionQuantScaleEnum
{
-    NO_SCALE  = 0,
-    PERTENSOR = 1,
-    BLOCKSCALE,
+    NO_SCALE      = 0,
+    PERTENSOR     = 1,
+    BLOCKSCALE    = 2,
+    KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
};
template <BlockAttentionQuantScaleEnum>


@@ -10,6 +10,7 @@
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
@@ -185,13 +186,45 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_lse = 0;
};
-    struct FmhaFwdCommonQScaleKargs
+    // PERTENSOR: Q/K/V all use per-tensor descales
+    struct FmhaFwdPerTensorQScaleKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
};
// KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
// K descale: [num_block, num_kv_head], V descale: [num_block, num_kv_head]
struct FmhaFwdKVBlockScaleKargs
{
const void* q_descale_ptr = nullptr; // Per-tensor Q descale
const void* k_descale_ptr = nullptr; // [num_block, num_kv_head]
const void* v_descale_ptr = nullptr; // [num_block, num_kv_head]
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
// Helper template to select QScale Kargs type based on QScaleEnum
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
struct GetQScaleKargs
{
using type = EmptyType;
};
template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
{
using type = FmhaFwdPerTensorQScaleKargs;
};
template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
{
using type = FmhaFwdKVBlockScaleKargs;
};
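
The helper resolves as one would expect from the primary template and the two specializations above; a few illustrative static checks, written as if at the scope where these nested helpers are defined (`<type_traits>` is already included by this file):

```cpp
using Empty = FmhaFwdEmptyKargs<3>;
static_assert(std::is_same_v<
    GetQScaleKargs<BlockAttentionQuantScaleEnum::NO_SCALE, Empty>::type, Empty>);
static_assert(std::is_same_v<
    GetQScaleKargs<BlockAttentionQuantScaleEnum::PERTENSOR, Empty>::type,
    FmhaFwdPerTensorQScaleKargs>);
static_assert(std::is_same_v<
    GetQScaleKargs<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, Empty>::type,
    FmhaFwdKVBlockScaleKargs>);
```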
struct FmhaFwdDropoutSeedOffset
{
template <typename T>
@@ -255,9 +288,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-          std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
-                             FmhaFwdCommonQScaleKargs,
-                             FmhaFwdEmptyKargs<3>>,
+          GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -276,9 +307,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-          std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
-                             FmhaFwdCommonQScaleKargs,
-                             FmhaFwdEmptyKargs<3>>,
+          GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -348,7 +377,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
-                                      const void* sink_ptr = nullptr)
+                                      const void* sink_ptr = nullptr,
+                                      ck_tile::index_t nblock_stride_kv_block_descale = 0,
+                                      ck_tile::index_t nhead_stride_kv_block_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -495,7 +534,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
-                                      const void* sink_ptr = nullptr)
+                                      const void* sink_ptr = nullptr,
+                                      ck_tile::index_t nblock_stride_kv_block_descale = 0,
+                                      ck_tile::index_t nhead_stride_kv_block_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -563,6 +604,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1157,11 +1206,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
assert(kargs.q_descale_ptr != nullptr);
assert(kargs.k_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
return kargs.scale_s * q_descale * k_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// Q is per-tensor, K is per-page (handled in pipeline)
assert(kargs.q_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
return kargs.scale_s * q_descale;
}
else
{
return kargs.scale_s;
@@ -1194,6 +1252,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
// TODO - move global load of descale to pipeline
assert(kargs.v_descale_ptr != nullptr);
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
@@ -1237,6 +1296,39 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
dropout,
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
assert(kargs.k_descale_ptr != nullptr);
assert(kargs.v_descale_ptr != nullptr);
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout,
sink_value,
k_descale_ptr,
v_descale_ptr,
kargs.nblock_stride_kv_block_descale,
kargs.nhead_stride_kv_block_descale);
}
else
{
return FmhaPipeline{}(q_dram_window,


@@ -7,13 +7,21 @@
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
-template <typename OffsetVecType,
+// Load physical pages from page_idx lookup table.
+// K cache: per-token lookup (each k0 may have different page_id)
+// V cache: depends on whether V tile crosses pages
+//   - Crosses pages: per-token lookup
+//   - Single page: lane0 lookup once, broadcast to all
+// Output: physical_pages array with kLoopCount elements
+template <typename IndexArrayType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
@@ -22,14 +30,11 @@ template <typename OffsetVecType,
index_t kLoopStride,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
-          index_t kN0,
-          index_t kVectorSize>
-CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
-                                                   const index_t& stride_token,
-                                                   const index_t& stride_page_block,
-                                                   const CoordVecType& coord_vec,
-                                                   OffsetVecType& kv_offset_vec,
-                                                   index_t global_seq_offset = 0)
+          index_t kN0>
+CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
+                                        const CoordVecType& coord_vec,
+                                        index_t global_seq_offset,
+                                        IndexArrayType& physical_pages)
{
static constexpr index_t kLog2PageSize = [] {
index_t shift = 0;
@@ -42,18 +47,16 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
return shift;
}();
-    const index_t& thread_coord_start = coord_vec[kCoordAxis];
-    constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
+    const index_t& thread_coord_start = coord_vec[kCoordAxis];
if constexpr(kIsKcache)
{
-        // for k offsets
+        // K cache: per-token lookup (all tokens may be on different pages)
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-            const index_t page_id           = global_token_idx >> kLog2PageSize;
-            const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
-            kv_offset_vec[k0] = static_cast<long_index_t>(page_idx[page_id]) * stride_page_block +
-                                static_cast<long_index_t>(token_idx_in_page) * stride_token;
+            const index_t page_id = global_token_idx >> kLog2PageSize;
+            physical_pages[k0]    = page_idx[page_id];
});
}
else
@@ -71,11 +74,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-                const long_index_t page_base_offset =
-                    static_cast<long_index_t>(page_idx[global_token_idx]) * stride_page_block;
-                kv_offset_vec[k0] = page_base_offset;
+                physical_pages[k0] = page_idx[global_token_idx];
});
}
else if constexpr(kVTileCrossesPages)
@@ -85,70 +84,131 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-                const index_t page_id           = global_token_idx >> kLog2PageSize;
-                const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
-                const long_index_t page_base_offset =
-                    static_cast<long_index_t>(page_idx[page_id]) * stride_page_block;
-                if constexpr(kKVMemoryLayout ==
-                             BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
-                {
-                    // Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
-                    // address pattern.
-                    const long_index_t token_offset =
-                        static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
-                                                  (stride_token * kVectorSize)) +
-                        (token_idx_in_page % kVectorSize);
-                    kv_offset_vec[k0] = page_base_offset + token_offset;
-                }
-                else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
-                {
-                    kv_offset_vec[k0] = page_base_offset +
-                                        static_cast<long_index_t>(token_idx_in_page) * stride_token;
-                }
+                const index_t page_id = global_token_idx >> kLog2PageSize;
+                physical_pages[k0]    = page_idx[page_id];
});
}
-        else // !kVTileCrossesPages
+        else
        {
-            // V tile is fully contained in one page, so page_id is shared.
-            // Use lane0 to compute page_id once and broadcast page_base_offset.
+            // V tile fully contained in one page: lane0 lookup, broadcast to all
            const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
            const index_t lane0_page_id =
                (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
+            const index_t shared_physical_page = page_idx[lane0_page_id];
+            static_for<0, kLoopCount, 1>{}(
+                [&](auto k0) { physical_pages[k0] = shared_physical_page; });
}
}
}
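
To ground the shift/mask arithmetic these helpers rely on: for a power-of-two page size, the page id is a right shift and the in-page index is a mask. A standalone sketch, assuming a page size of 16 (illustrative):

```cpp
#include <cassert>

int main()
{
    constexpr int kPageBlockSize    = 16;
    constexpr int kLog2PageSize     = 4;                         // log2(16)
    constexpr int kInPageOffsetMask = (1 << kLog2PageSize) - 1;  // 0b1111

    const int global_token_idx  = 37;
    const int page_id           = global_token_idx >> kLog2PageSize;    // 2
    const int token_idx_in_page = global_token_idx & kInPageOffsetMask; // 5

    assert(page_id * kPageBlockSize + token_idx_in_page == global_token_idx);
    return 0;
}
```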
// kv_offset_array_transform: Converts logical token indices to physical memory offsets
// for paged KV cache access.
//
// This version uses pre-loaded physical_pages array from load_physical_pages().
// Benefits:
// - page_idx is read only once (by load_physical_pages)
// - physical_pages can be prefetched before GEMM to hide memory latency
// - physical_pages can be reused for descale lookup (KV_BLOCKSCALE)
//
// Template parameters:
// - kCoordAxis: Which axis of coord_vec contains the thread's token coordinate
// - kPageBlockSize: Number of tokens per page (must be power of 2)
// - kLoopStart/kLoopCount/kLoopStride: Loop iteration parameters for static_for
// - kKVMemoryLayout: VECTORIZED_LAYOUT or LINEAR_LAYOUT
// - kIsKcache: true for K cache, false for V cache
// - kN0: Tile size in N dimension (used for page crossing detection)
// - kVectorSize: Vector size for vectorized layout (e.g., 8 for fp8)
//
// Memory layout for V cache:
// LINEAR_LAYOUT: [page, token_in_page, head_dim]
// VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
//
template <typename IndexArrayType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
index_t kLoopStart,
index_t kLoopCount,
index_t kLoopStride,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
const index_t& stride_token,
const index_t& stride_page_block,
const CoordVecType& coord_vec,
IndexArrayType& kv_offset_vec,
index_t global_seq_offset = 0)
{
static constexpr index_t kLog2PageSize = [] {
index_t shift = 0;
index_t val = kPageBlockSize;
while(val > 1)
{
val >>= 1;
shift++;
}
return shift;
}();
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
{
// K cache: per-token lookup
// Each token may be on a different page, so we use physical_pages[k0] for each.
// Offset = physical_page * stride_page_block + token_idx_in_page * stride_token
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
});
}
    else
    {
        // V cache: use physical_pages[k0] for each token
        // physical_pages was already populated correctly by load_physical_pages(), handling:
        // - page_size=1: page_idx maps token_idx -> physical_page directly
        // - V tile crosses pages: per-token page lookup
        // - V tile in single page: lane0 lookup with broadcast to all lanes
        static_for<0, kLoopCount, 1>{}([&](auto k0) {
            const index_t global_token_idx =
                global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
            const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
            const index_t physical_page     = physical_pages[k0];
            const long_index_t page_base_offset =
                static_cast<long_index_t>(physical_page) * stride_page_block;
            if constexpr(kKVMemoryLayout ==
                         BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
            {
                // Vectorized layout offset calculation:
                // Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
                // Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) +
                //          (token % kVectorSize)
                const long_index_t token_offset =
                    static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
                                              (stride_token * kVectorSize)) +
                    (token_idx_in_page % kVectorSize);
                kv_offset_vec[k0] = page_base_offset + token_offset;
            }
            else // LINEAR_LAYOUT
            {
                // Linear layout: [page, token_in_page, head_dim]
                // Offset = page_base + token_idx_in_page * stride_token
                kv_offset_vec[k0] =
                    page_base_offset + static_cast<long_index_t>(token_idx_in_page) * stride_token;
            }
        });
    }
}
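
A standalone numeric check of the two V-cache offset formulas above, assuming illustrative sizes (head_dim = 128, kVectorSize = 8 as for fp8):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    const std::int64_t stride_token      = 128; // head_dim elements per token
    const std::int64_t kVectorSize       = 8;
    const std::int64_t page_base_offset  = 0;   // single page, for clarity
    const std::int64_t token_idx_in_page = 11;

    // LINEAR_LAYOUT: [page, token_in_page, head_dim]
    const std::int64_t linear = page_base_offset + token_idx_in_page * stride_token;
    assert(linear == 11 * 128);

    // VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
    const std::int64_t vectorized =
        page_base_offset +
        (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
        (token_idx_in_page % kVectorSize);
    assert(vectorized == 1 * (128 * 8) + 3);
    return 0;
}
```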
@@ -209,6 +269,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
// For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
// This avoids explicit P *= scale_p and v_descale /= scale_p operations
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -341,8 +407,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout,
-               const float sink_v) const
+               const float sink_v,
+               // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE)
+               const float* k_descale_ptr = nullptr,
+               const float* v_descale_ptr = nullptr,
+               index_t nblock_stride_kv_block_descale = 0,
+               index_t nhead_stride_kv_block_descale = 0) const
{
// KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
// all tokens in a main loop iteration belong to the same page
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0");
}
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
@@ -494,6 +572,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
index_t current_seq_k = seqlen_k_start;
// Load physical pages first, then compute offsets.
// k_physical_pages can be reused for descale lookup later.
statically_indexed_array<index_t, NRepeat> k_physical_pages{};
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
@@ -505,7 +598,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
true,
kN0,
kVectorSize>(
-        page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
+        k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
@@ -644,6 +737,52 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
"V page-index Y dim must be valid");
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
// V physical pages array for use with kv_offset_array_transform
// For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
statically_indexed_array<index_t, V_PageIdxRepeat> v_physical_pages{};
// Prefetch V physical pages - can be called early to hide buffer load latency
auto prefetch_v_physical_pages = [&](auto k_loop_start) {
constexpr index_t kLoopStart = decltype(k_loop_start)::value;
if constexpr(V_KIterOuter > 1)
{
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
// Load physical pages for this k2 slice into the appropriate portion of array
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2{};
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart + k2.value * V_KLanes * V_KIterInner,
V_KIterInner,
1,
kKVMemoryLayout,
false,
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2);
// Copy to merged array
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_physical_pages[idx] = v_physical_pages_k2[k1];
});
});
}
else
{
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart,
V_KIterInner,
1,
kKVMemoryLayout,
false,
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages);
}
};
// Update V offsets using pre-loaded physical pages
auto update_v_offsets = [&](auto k_loop_start) {
constexpr index_t kLoopStart = decltype(k_loop_start)::value;
// For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice
@@ -653,6 +792,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
// Extract physical pages for this k2 slice
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2;
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_physical_pages_k2[k1] = v_physical_pages[idx];
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
@@ -663,8 +809,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
-                                          kVectorSize>(
-            page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k);
+                                          kVectorSize>(v_physical_pages_k2,
+                                                       stride_v,
+                                                       page_stride_v,
+                                                       v_coord,
+                                                       v_offsets_k2,
+                                                       current_seq_k);
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_offsets[idx] = v_offsets_k2[k1];
@@ -684,9 +835,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
false,
kN0,
kVectorSize>(
-            page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
+            v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}
};
// Prefetch V physical pages early to hide buffer load latency
prefetch_v_physical_pages(number<0>{});
update_v_offsets(number<0>{});
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -717,6 +871,41 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// main loop
do
{
// KV_BLOCKSCALE: load per-page K/V descale factors
// Uses k_physical_pages[0] from load_physical_pages to avoid redundant page_idx reads.
// Assumes kPageBlockSize >= kN0, so all tokens in one main loop iteration belong to
// the same page (single scale pair).
//
// TODO: Cross-page KV_BLOCKSCALE support
// Currently only supports kPageBlockSize >= kN0 (all tokens in tile on same page).
// To support smaller page sizes (cross-page tiles), need:
//
// 1. K descale: Load per-token k_descale_vec[NRepeat] based on k_physical_pages[k0]
// - After GEMM0 (S = Q × K^T), apply column-wise scaling: S[:,j] *= k_descale[j]
// - Requires modifying s_acc_element_func to accept column index
//
// 2. V descale: Load per-token v_descale_vec[V_PageIdxRepeat] based on
// v_physical_pages[k0]
// - Before GEMM1 (O = P × V), apply row-wise scaling to P: P[i,j] *= v_descale[j]
// - Or pre-scale V in LDS (more complex)
//
// 3. K and V may be on different pages for the same token index, so need separate
// lookups
//
[[maybe_unused]] float k_descale = 1.0f;
[[maybe_unused]] float v_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
const index_t scale_offset =
k_physical_pages[number<0>{}] * nblock_stride_kv_block_descale +
block_indices.kv_head_idx * nhead_stride_kv_block_descale;
k_descale = k_descale_ptr[scale_offset];
v_descale = v_descale_ptr[scale_offset];
}
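
For reference, with the row-major [num_block, num_kv_head] descale layout suggested earlier, the scale_offset above reduces to physical_page * num_kv_head + kv_head_idx. A hedged standalone sketch (values illustrative):

```cpp
#include <cassert>

int main()
{
    const int num_kv_head = 8; // assumed
    const int nblock_stride_kv_block_descale = num_kv_head;
    const int nhead_stride_kv_block_descale  = 1;

    const int physical_page = 42; // k_physical_pages[0] for this tile
    const int kv_head_idx   = 3;

    const int scale_offset = physical_page * nblock_stride_kv_block_descale +
                             kv_head_idx * nhead_stride_kv_block_descale;
    assert(scale_offset == 42 * 8 + 3);
    return 0;
}
```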
// Prefetch V physical pages early - overlaps with GEMM0 computation
prefetch_v_physical_pages(number<kK1>{});
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
@@ -763,9 +952,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(1);
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
// V physical pages already prefetched before GEMM0
update_v_offsets(number<kK1>{});
v_dram_window.update_page_idx(v_offsets);
// KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
}
const auto p = [&]() {
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
@@ -875,6 +1071,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
// Prefetch V physical pages early - overlaps with softmax computation
if constexpr(k1_loops > 1)
{
prefetch_v_physical_pages(number<2 * kK1>{});
}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
@@ -953,7 +1156,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                auto row_max = scale_s * get_validated_m(m[i_idx]);
+                // For KV_BLOCKSCALE: precompute (m - shift) once per row
+                // exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift
+                // This scales P by 2^shift (2^8 = 256, within the fp8_e4m3 max of 448)
+                // without an explicit multiply
+                auto validated_m = get_validated_m(m[i_idx]);
+                auto row_max     = scale_s * validated_m;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
@@ -961,13 +1178,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
-                    p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
+                    p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
}
else
{
if constexpr(kHasLogitsSoftCap)
{
-                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
+                        p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
}
else
{
@@ -1049,6 +1266,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}();
// STAGE 3, KV gemm
// KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale
auto o_acc_unscaled = decltype(o_acc){};
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
clear_tile(o_acc_unscaled);
}
// Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE (needs v_descale), o_acc
// otherwise
auto& gemm1_acc = [&]() -> auto& {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
return o_acc_unscaled;
else
return o_acc;
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
@@ -1056,11 +1289,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
// Update V offsets using previously prefetched physical pages
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
}
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
if constexpr(i_k1 + 1 < k1_loops - 1)
{
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
}
block_sync_lds();
-                gemm_1(o_acc,
+                gemm_1(gemm1_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
@@ -1104,6 +1345,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
// KV_BLOCKSCALE: reload physical pages for the new tile
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
@@ -1115,7 +1368,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
true,
kN0,
kVectorSize>(
-                page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
+                k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
@@ -1131,13 +1384,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
block_sync_lds();
gemm_1(
-                o_acc,
+                gemm1_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
// KV_BLOCKSCALE: apply v_descale and accumulate o_acc_unscaled into o_acc
// Note: No division by scale_p needed because:
// 1. P was scaled by 2^shift through exp2 shift trick
// 2. rowsum l was also scaled by 2^shift
// 3. Final O = sum(P*V) / l, so the 2^shift cancels out
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
tile_elementwise_inout(
[&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; },
o_acc,
o_acc_unscaled);
}
} while(i_total_loops < num_total_loop);
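
A standalone check of the cancellation argument above: the 2^shift carried by P appears in both the numerator sum(P*V) and the row sum l, so the final O is unchanged. Toy two-element row, shift = 8 as for OCP fp8 (illustrative, not kernel code):

```cpp
#include <cassert>
#include <cmath>

int main()
{
    const float shift = 8.0f;
    const float s[2] = {-1.0f, 0.5f}, v[2] = {0.25f, 0.75f};
    const float m = 0.5f; // row max

    float num_plain = 0, l_plain = 0, num_shift = 0, l_shift = 0;
    for(int j = 0; j < 2; ++j)
    {
        const float p  = std::exp2f(s[j] - m);
        const float ps = std::exp2f(s[j] - m + shift); // == p * 256
        num_plain += p * v[j];
        l_plain   += p;
        num_shift += ps * v[j];
        l_shift   += ps;
    }
    // Same output either way: 2^shift cancels between numerator and rowsum.
    assert(std::fabs(num_plain / l_plain - num_shift / l_shift) < 1e-6f);
    return 0;
}
```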
// store lse
@@ -1257,6 +1523,77 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
dropout,
sink_v);
}
    // Overload for KV_BLOCKSCALE: K/V descale is per-page.
    // This is a convenience overload that forwards to the main operator() with the KV descale
    // parameters.
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout,
float sink_v,
const float* k_descale_ptr,
const float* v_descale_ptr,
index_t nblock_stride_kv_block_descale,
index_t nhead_stride_kv_block_descale) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k,
stride_v,
page_stride_k,
page_stride_v,
dropout,
sink_v,
k_descale_ptr,
v_descale_ptr,
nblock_stride_kv_block_descale,
nhead_stride_kv_block_descale);
}
};
} // namespace ck_tile