mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)
[CK] Add FP8 KV_BLOCKSCALE support for batch prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement per-page K/V quantization for paged attention: - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum - Use exp2 shift trick to eliminate explicit P scaling overhead - Prefetch physical pages offset for KV cache, overlaps with computations ## Proposed changes Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
62fbda4d1e
commit
7b18f5fed2
2
Jenkinsfile
vendored
2
Jenkinsfile
vendored
@@ -1784,7 +1784,7 @@ pipeline {
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCK_CXX_STANDARD="17" """
|
||||
execute_args = build_client_examples_and_codegen_tests("gfx90a")
|
||||
execute_args = build_client_examples("gfx90a")
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
|
||||
@@ -78,12 +78,14 @@ QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -677,7 +677,7 @@ class KernelComponentFactory:
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
["pertensor", "kv_blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
@@ -740,6 +740,10 @@ def get_fwd_blobs(
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
|
||||
@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
|
||||
// KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
|
||||
// k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
|
||||
// k_descale_ptr: [num_block, num_kv_head] - points to k block descale
|
||||
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
// keep sync with BlockAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale,
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -31,6 +32,8 @@ struct quant_scale_info
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
else if(type == quant_scale_enum::kv_blockscale)
|
||||
os << "kvbs";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -48,6 +51,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else if(str == "kvbs" || str == "3")
|
||||
{
|
||||
info.type = quant_scale_enum::kv_blockscale;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
@@ -10,9 +10,10 @@ namespace ck_tile {
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE,
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE = 2,
|
||||
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -185,13 +186,45 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonQScaleKargs
|
||||
// PERTENSOR: Q/K/V all use per-tensor descales
|
||||
struct FmhaFwdPerTensorQScaleKargs
|
||||
{
|
||||
const void* q_descale_ptr = nullptr;
|
||||
const void* k_descale_ptr = nullptr;
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
// KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
|
||||
// K descale: [num_block, num_kv_head], V descale: [num_block, num_kv_head]
|
||||
struct FmhaFwdKVBlockScaleKargs
|
||||
{
|
||||
const void* q_descale_ptr = nullptr; // Per-tensor Q descale
|
||||
const void* k_descale_ptr = nullptr; // [num_block, num_kv_head]
|
||||
const void* v_descale_ptr = nullptr; // [num_block, num_kv_head]
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
// Helper template to select QScale Kargs type based on QScaleEnum
|
||||
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
|
||||
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
|
||||
struct GetQScaleKargs
|
||||
{
|
||||
using type = EmptyType;
|
||||
};
|
||||
|
||||
template <typename EmptyType>
|
||||
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
|
||||
{
|
||||
using type = FmhaFwdPerTensorQScaleKargs;
|
||||
};
|
||||
|
||||
template <typename EmptyType>
|
||||
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
|
||||
{
|
||||
using type = FmhaFwdKVBlockScaleKargs;
|
||||
};
|
||||
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
template <typename T>
|
||||
@@ -255,9 +288,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -276,9 +307,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -348,7 +377,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0,
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
|
||||
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -495,7 +534,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0,
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -563,6 +604,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
|
||||
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -1157,11 +1206,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const float scale_s = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
|
||||
|
||||
return kargs.scale_s * q_descale * k_descale;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
// Q is per-tensor, K is per-page (handled in pipeline)
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
return kargs.scale_s * q_descale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.scale_s;
|
||||
@@ -1194,6 +1252,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
|
||||
|
||||
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
@@ -1237,6 +1296,39 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
|
||||
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout,
|
||||
sink_value,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.nblock_stride_kv_block_descale,
|
||||
kargs.nhead_stride_kv_block_descale);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
|
||||
@@ -7,13 +7,21 @@
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename OffsetVecType,
|
||||
|
||||
// Load physical pages from page_idx lookup table.
|
||||
// K cache: per-token lookup (each k0 may have different page_id)
|
||||
// V cache: depends on whether V tile crosses pages
|
||||
// - Crosses pages: per-token lookup
|
||||
// - Single page: lane0 lookup once, broadcast to all
|
||||
// Output: physical_pages array with kLoopCount elements
|
||||
template <typename IndexArrayType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
@@ -22,14 +30,11 @@ template <typename OffsetVecType,
|
||||
index_t kLoopStride,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kN0,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
const index_t& stride_token,
|
||||
const index_t& stride_page_block,
|
||||
const CoordVecType& coord_vec,
|
||||
OffsetVecType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
index_t kN0>
|
||||
CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
const CoordVecType& coord_vec,
|
||||
index_t global_seq_offset,
|
||||
IndexArrayType& physical_pages)
|
||||
{
|
||||
static constexpr index_t kLog2PageSize = [] {
|
||||
index_t shift = 0;
|
||||
@@ -42,18 +47,16 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
return shift;
|
||||
}();
|
||||
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// for k offsets
|
||||
// K cache: per-token lookup (all tokens may be on different pages)
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(page_idx[page_id]) * stride_page_block +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -71,11 +74,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(page_idx[global_token_idx]) * stride_page_block;
|
||||
|
||||
kv_offset_vec[k0] = page_base_offset;
|
||||
physical_pages[k0] = page_idx[global_token_idx];
|
||||
});
|
||||
}
|
||||
else if constexpr(kVTileCrossesPages)
|
||||
@@ -85,70 +84,131 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(page_idx[page_id]) * stride_page_block;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
|
||||
// address pattern.
|
||||
const long_index_t token_offset =
|
||||
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
|
||||
(stride_token * kVectorSize)) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_base_offset +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
}
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
});
|
||||
}
|
||||
else // !kVTileCrossesPages
|
||||
else
|
||||
{
|
||||
// V tile is fully contained in one page, so page_id is shared.
|
||||
// Use lane0 to compute page_id once and broadcast page_base_offset.
|
||||
// V tile fully contained in one page: lane0 lookup, broadcast to all
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
const index_t shared_physical_page = page_idx[lane0_page_id];
|
||||
|
||||
static_for<0, kLoopCount, 1>{}(
|
||||
[&](auto k0) { physical_pages[k0] = shared_physical_page; });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// kv_offset_array_transform: Converts logical token indices to physical memory offsets
|
||||
// for paged KV cache access.
|
||||
//
|
||||
// This version uses pre-loaded physical_pages array from load_physical_pages().
|
||||
// Benefits:
|
||||
// - page_idx is read only once (by load_physical_pages)
|
||||
// - physical_pages can be prefetched before GEMM to hide memory latency
|
||||
// - physical_pages can be reused for descale lookup (KV_BLOCKSCALE)
|
||||
//
|
||||
// Template parameters:
|
||||
// - kCoordAxis: Which axis of coord_vec contains the thread's token coordinate
|
||||
// - kPageBlockSize: Number of tokens per page (must be power of 2)
|
||||
// - kLoopStart/kLoopCount/kLoopStride: Loop iteration parameters for static_for
|
||||
// - kKVMemoryLayout: VECTORIZED_LAYOUT or LINEAR_LAYOUT
|
||||
// - kIsKcache: true for K cache, false for V cache
|
||||
// - kN0: Tile size in N dimension (used for page crossing detection)
|
||||
// - kVectorSize: Vector size for vectorized layout (e.g., 8 for fp8)
|
||||
//
|
||||
// Memory layout for V cache:
|
||||
// LINEAR_LAYOUT: [page, token_in_page, head_dim]
|
||||
// VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
|
||||
//
|
||||
template <typename IndexArrayType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
index_t kLoopStart,
|
||||
index_t kLoopCount,
|
||||
index_t kLoopStride,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kN0,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
|
||||
const index_t& stride_token,
|
||||
const index_t& stride_page_block,
|
||||
const CoordVecType& coord_vec,
|
||||
IndexArrayType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
{
|
||||
static constexpr index_t kLog2PageSize = [] {
|
||||
index_t shift = 0;
|
||||
index_t val = kPageBlockSize;
|
||||
while(val > 1)
|
||||
{
|
||||
val >>= 1;
|
||||
shift++;
|
||||
}
|
||||
return shift;
|
||||
}();
|
||||
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
|
||||
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// K cache: per-token lookup
|
||||
// Each token may be on a different page, so we use physical_pages[k0] for each.
|
||||
// Offset = physical_page * stride_page_block + token_idx_in_page * stride_token
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
});
|
||||
}
|
||||
else // !kVTileCrossesPages
|
||||
{
|
||||
// V cache: use physical_pages[k0] for each token
|
||||
// physical_pages was already populated correctly by load_physical_pages(), handling:
|
||||
// - page_size=1: page_idx maps token_idx -> physical_page directly
|
||||
// - V tile crosses pages: per-token page lookup
|
||||
// - V tile in single page: lane0 lookup with broadcast to all lanes
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;
|
||||
static_cast<long_index_t>(physical_page) * stride_page_block;
|
||||
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
// kLoopStride allows non-unit token spacing in the tile distribution.
|
||||
const index_t token_idx_in_page =
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) &
|
||||
kInPageOffsetMask;
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout offset calculation:
|
||||
// Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
|
||||
// Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) +
|
||||
// (token % kVectorSize)
|
||||
const long_index_t token_offset =
|
||||
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
|
||||
(stride_token * kVectorSize)) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout offset
|
||||
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) +
|
||||
// (token_idx_in_page % kVectorSize)
|
||||
|
||||
const long_index_t token_offset =
|
||||
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
|
||||
(stride_token * kVectorSize)) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_base_offset +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
}
|
||||
});
|
||||
}
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else // LINEAR_LAYOUT
|
||||
{
|
||||
// Linear layout: [page, token_in_page, head_dim]
|
||||
// Offset = page_base + token_idx_in_page * stride_token
|
||||
kv_offset_vec[k0] =
|
||||
page_base_offset + static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,6 +269,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
|
||||
// For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
// This avoids explicit P *= scale_p and v_descale /= scale_p operations
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -341,8 +407,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
const float sink_v,
|
||||
// KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE)
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
index_t nblock_stride_kv_block_descale = 0,
|
||||
index_t nhead_stride_kv_block_descale = 0) const
|
||||
{
|
||||
// KV_BLOCKSCALE requires page_block_size >= kN0 to ensure
|
||||
// all tokens in a main loop iteration belong to the same page
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0");
|
||||
}
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
@@ -494,6 +572,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
index_t current_seq_k = seqlen_k_start;
|
||||
|
||||
// Load physical pages first, then compute offsets.
|
||||
// k_physical_pages can be reused for descale lookup later.
|
||||
statically_indexed_array<index_t, NRepeat> k_physical_pages{};
|
||||
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
@@ -505,7 +598,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
true,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
|
||||
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
@@ -644,6 +737,52 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
"V page-index Y dim must be valid");
|
||||
|
||||
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
|
||||
// V physical pages array for use with kv_offset_array_transform
|
||||
// For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
|
||||
statically_indexed_array<index_t, V_PageIdxRepeat> v_physical_pages{};
|
||||
|
||||
// Prefetch V physical pages - can be called early to hide buffer load latency
|
||||
auto prefetch_v_physical_pages = [&](auto k_loop_start) {
|
||||
constexpr index_t kLoopStart = decltype(k_loop_start)::value;
|
||||
if constexpr(V_KIterOuter > 1)
|
||||
{
|
||||
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
|
||||
// Load physical pages for this k2 slice into the appropriate portion of array
|
||||
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2{};
|
||||
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
|
||||
decltype(v_coord),
|
||||
I1,
|
||||
kPageBlockSize,
|
||||
kLoopStart + k2.value * V_KLanes * V_KIterInner,
|
||||
V_KIterInner,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2);
|
||||
|
||||
// Copy to merged array
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
|
||||
v_physical_pages[idx] = v_physical_pages_k2[k1];
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
|
||||
decltype(v_coord),
|
||||
I1,
|
||||
kPageBlockSize,
|
||||
kLoopStart,
|
||||
V_KIterInner,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages);
|
||||
}
|
||||
};
|
||||
|
||||
// Update V offsets using pre-loaded physical pages
|
||||
auto update_v_offsets = [&](auto k_loop_start) {
|
||||
constexpr index_t kLoopStart = decltype(k_loop_start)::value;
|
||||
// For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice
|
||||
@@ -653,6 +792,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
|
||||
statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
|
||||
// Extract physical pages for this k2 slice
|
||||
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2;
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
|
||||
v_physical_pages_k2[k1] = v_physical_pages[idx];
|
||||
});
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
|
||||
decltype(v_coord),
|
||||
I1,
|
||||
@@ -663,8 +809,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k);
|
||||
kVectorSize>(v_physical_pages_k2,
|
||||
stride_v,
|
||||
page_stride_v,
|
||||
v_coord,
|
||||
v_offsets_k2,
|
||||
current_seq_k);
|
||||
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
|
||||
v_offsets[idx] = v_offsets_k2[k1];
|
||||
@@ -684,9 +835,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
}
|
||||
};
|
||||
|
||||
// Prefetch V physical pages early to hide buffer load latency
|
||||
prefetch_v_physical_pages(number<0>{});
|
||||
update_v_offsets(number<0>{});
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -717,6 +871,41 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
// KV_BLOCKSCALE: load per-page K/V descale factors
|
||||
// Uses k_physical_pages[0] from load_physical_pages to avoid redundant page_idx reads.
|
||||
// Assumes kPageBlockSize >= kN0, so all tokens in one main loop iteration belong to
|
||||
// the same page (single scale pair).
|
||||
//
|
||||
// TODO: Cross-page KV_BLOCKSCALE support
|
||||
// Currently only supports kPageBlockSize >= kN0 (all tokens in tile on same page).
|
||||
// To support smaller page sizes (cross-page tiles), need:
|
||||
//
|
||||
// 1. K descale: Load per-token k_descale_vec[NRepeat] based on k_physical_pages[k0]
|
||||
// - After GEMM0 (S = Q × K^T), apply column-wise scaling: S[:,j] *= k_descale[j]
|
||||
// - Requires modifying s_acc_element_func to accept column index
|
||||
//
|
||||
// 2. V descale: Load per-token v_descale_vec[V_PageIdxRepeat] based on
|
||||
// v_physical_pages[k0]
|
||||
// - Before GEMM1 (O = P × V), apply row-wise scaling to P: P[i,j] *= v_descale[j]
|
||||
// - Or pre-scale V in LDS (more complex)
|
||||
//
|
||||
// 3. K and V may be on different pages for the same token index, so need separate
|
||||
// lookups
|
||||
//
|
||||
[[maybe_unused]] float k_descale = 1.0f;
|
||||
[[maybe_unused]] float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
const index_t scale_offset =
|
||||
k_physical_pages[number<0>{}] * nblock_stride_kv_block_descale +
|
||||
block_indices.kv_head_idx * nhead_stride_kv_block_descale;
|
||||
k_descale = k_descale_ptr[scale_offset];
|
||||
v_descale = v_descale_ptr[scale_offset];
|
||||
}
|
||||
|
||||
// Prefetch V physical pages early - overlaps with GEMM0 computation
|
||||
prefetch_v_physical_pages(number<kK1>{});
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
if constexpr(k0_loops > 1)
|
||||
@@ -763,9 +952,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
// V physical pages already prefetched before GEMM0
|
||||
update_v_offsets(number<kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
// KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
@@ -875,6 +1071,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
// Prefetch V physical pages early - overlaps with softmax computation
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
prefetch_v_physical_pages(number<2 * kK1>{});
|
||||
}
|
||||
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
@@ -953,7 +1156,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
// For KV_BLOCKSCALE: precompute (m - shift) once per row
|
||||
// exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift
|
||||
// This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -961,13 +1178,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1049,6 +1266,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
// KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale
|
||||
auto o_acc_unscaled = decltype(o_acc){};
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
clear_tile(o_acc_unscaled);
|
||||
}
|
||||
|
||||
// Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE (needs v_descale), o_acc
|
||||
// otherwise
|
||||
auto& gemm1_acc = [&]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
return o_acc_unscaled;
|
||||
else
|
||||
return o_acc;
|
||||
}();
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
@@ -1056,11 +1289,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
// Update V offsets using previously prefetched physical pages
|
||||
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
|
||||
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
|
||||
if constexpr(i_k1 + 1 < k1_loops - 1)
|
||||
{
|
||||
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
gemm_1(gemm1_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
@@ -1104,6 +1345,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
// KV_BLOCKSCALE: reload physical pages for the new tile
|
||||
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
@@ -1115,7 +1368,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
true,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
@@ -1131,13 +1384,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
gemm1_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
|
||||
// KV_BLOCKSCALE: apply v_descale and accumulate o_acc_unscaled into o_acc
|
||||
// Note: No division by scale_p needed because:
|
||||
// 1. P was scaled by 2^shift through exp2 shift trick
|
||||
// 2. rowsum l was also scaled by 2^shift
|
||||
// 3. Final O = sum(P*V) / l, so the 2^shift cancels out
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; },
|
||||
o_acc,
|
||||
o_acc_unscaled);
|
||||
}
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
@@ -1257,6 +1523,77 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
|
||||
// Overload for KV_BLOCKSCALE: K/V descale is per-page
|
||||
// This is a convenience overload that forwards to the main operator() with kv_scale parameters
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout,
|
||||
float sink_v,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
index_t nblock_stride_kv_block_descale,
|
||||
index_t nhead_stride_kv_block_descale) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout,
|
||||
sink_v,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
nblock_stride_kv_block_descale,
|
||||
nhead_stride_kv_block_descale);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user