mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[FMHA] Batch Prefill Support Improvements: Change KV Cache Layout & Large Page Size Support (#3442)
* add page_block_size parameter * add is_sglang_layout to parameters * add kv_offset_array_transform to batch async for page size 16 * add kv_last_page_lens to kernel * change kv layout to [num_total_pages, page_block_size, hdim] * format * - enable codegen of batch_prefill kernels - create new problem struct BlockFmhaBatchPrefillPipelineProblem for batch prefill kernels - generate different page sizes of batch prefill kernels (1, 16) * 1. fix wrong calculation of page id in kv_offset_array_transform in gfx950 2. support page size 1024 * fix python format * change kv cache layout to [num_blocks, num_kv_heads, head_size/x, block_size, x] and [num_blocks, num_kv_heads, block_size/X, head_size, X] * 1. Introduced `kVectorSize` in BlockFmhaBatchPrefillPipelineProblem instead of using hardcode values 2. Makes batch prefill kernel traits structures inherent from fmha fwd traits 3. Add some static check for Page size, vector size, hdim, ..., etc. * [Refactor] Replace is_sglang_layout with Enums for KV cache configuration Refactored `fmha_batch_prefill` to use `BlockAttentionKVCacheMemoryLayoutEnum` (VECTORIZED/LINEAR) and `BlockAttentionKVCacheLookupTableEnum` (SGLANG_1D/VLLM_2D) instead of a single boolean. **Changes:** * Added Enum definitions in `block_attention_kvcache_layout_enum.hpp`. * Updated Kernel, Pipeline, and Traits to template on these Enums. * Implemented `kv_offset_array_transform` logic based on `kKVMemoryLayout`. * Refactored `PageBlockTableKargs` to adapt to `kKVLookupTable`. * Updated CodeGen scripts to support new parameters. This decouples memory layout from the paging mechanism, enabling flexible KV cache configurations. * 1. remove batch prefill pipeline with sk_pad=false 2. correct some comments 3. add static assert to make sure v offsets is in same page within a tile. * fix vgpr spill count * remove unnecessary t2s functions * add fp8 support for receipt 200 and 600 in fmha_bath_prefill.py * support linear kv cache layout * Remove block_table_ptr from fwd_batch_prefill_args. Instead, reuse kv_page_indices as a pointer of the lookup table. * 1. merge multiple transforms into single transform. 2. add static check to make sure vlayout is row-major. * move FmhaFwdCommonKargs::seqlen_k_ptr to VllmPageTableKargs. * update changelog --------- Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
@@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
|
||||
static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
|
||||
static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
|
||||
static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct FmhaFwdEmptyKargs
|
||||
@@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct SglangPageTableKargs
|
||||
{
|
||||
const int32_t* kv_indptr;
|
||||
const int32_t* kv_page_indices;
|
||||
const int32_t* kv_last_page_lens;
|
||||
};
|
||||
|
||||
struct VllmPageTableKargs
|
||||
{
|
||||
const int32_t* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
};
|
||||
|
||||
using PageBlockTableKargs =
|
||||
std::conditional_t<kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
SglangPageTableKargs,
|
||||
VllmPageTableKargs>;
|
||||
|
||||
struct FmhaFwdCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
@@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
|
||||
int32_t num_total_pages;
|
||||
const int32_t* kv_indptr;
|
||||
const int32_t* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#else
|
||||
static constexpr ck_tile::index_t page_block_size = 1;
|
||||
#endif
|
||||
PageBlockTableKargs page_table;
|
||||
|
||||
float scale_s;
|
||||
|
||||
@@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
const PageBlockTableKargs& page_table,
|
||||
float scale_s,
|
||||
[[maybe_unused]] float scale_p,
|
||||
[[maybe_unused]] float scale_o,
|
||||
@@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_total_pages,
|
||||
reinterpret_cast<const int32_t*>(kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(kv_page_indices),
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
reinterpret_cast<const int32_t*>(kv_last_page_lens),
|
||||
page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
const PageBlockTableKargs& page_table,
|
||||
float scale_s,
|
||||
[[maybe_unused]] float scale_p,
|
||||
[[maybe_unused]] float scale_o,
|
||||
@@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_total_pages,
|
||||
reinterpret_cast<const int32_t*>(kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(kv_page_indices),
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
reinterpret_cast<const int32_t*>(kv_last_page_lens),
|
||||
page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
|
||||
#endif
|
||||
const index_t seqlen_k = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
|
||||
const int32_t num_page_blocks = page_end - page_start;
|
||||
const int32_t last_page_len = [&]() {
|
||||
if constexpr(kPageBlockSize == 1)
|
||||
return static_cast<int32_t>(kPageBlockSize);
|
||||
else
|
||||
return kargs.page_table.kv_last_page_lens[i_batch];
|
||||
}();
|
||||
return num_page_blocks > 0
|
||||
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
|
||||
last_page_len)
|
||||
: 0;
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
if(kargs.page_table.seqlen_k_ptr != nullptr)
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
|
||||
else
|
||||
return kargs.seqlen_k;
|
||||
}
|
||||
}();
|
||||
const int32_t* page_idx = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
return kargs.page_table.block_table_ptr +
|
||||
static_cast<long_index_t>(i_batch) *
|
||||
kargs.page_table.batch_stride_block_table;
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
@@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
@@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
return;
|
||||
}
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
kargs.seqlen_k = seqlen_k;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
@@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
kargs.seqlen_k = seqlen_k;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
// Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
|
||||
// Logical View for Pipeline: (TotalSeqK, D)
|
||||
|
||||
// Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
|
||||
// PageBlockSize, kVectorSize)
|
||||
// Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages,
|
||||
kargs.hdim_q / kVectorSize,
|
||||
kargs.page_block_size,
|
||||
kVectorSize),
|
||||
make_tuple(
|
||||
kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
// Merge to (TotalSeqK, D) in a single transform:
|
||||
// physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
|
||||
auto k_dram_2d = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size)), // TotalSeqK
|
||||
make_merge_transform(
|
||||
make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
|
||||
static_cast<int32_t>(kVectorSize)))), // D
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
k_dram_2d,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
|
||||
// Logical View for Pipeline: (TotalSeqK, D)
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
// Merge to (TotalSeqK, D) in a single transform:
|
||||
// physical (Page, S, D) -> logical (TotalSeqK, D)
|
||||
auto k_dram_2d = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size)),
|
||||
make_pass_through_transform(kargs.hdim_q)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_2d,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
|
||||
// Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
|
||||
|
||||
// Define the naive physical view with 4D shape: (NumPages,
|
||||
// PageBlockSize/kVectorSize, HeadDim, kVectorSize)
|
||||
// Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size / kVectorSize,
|
||||
kargs.hdim_v,
|
||||
kVectorSize),
|
||||
make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
// Merge to (D, TotalSeqK) in a single transform:
|
||||
// physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
|
||||
auto v_dram_final = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v), // D
|
||||
make_merge_transform(make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size / kVectorSize,
|
||||
kVectorSize))), // TotalSeqK
|
||||
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_final,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
|
||||
// Logical View for Pipeline: (D, TotalSeqK)
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
|
||||
return pad_tensor_view(
|
||||
// Merge to (D, TotalSeqK) in a single transform:
|
||||
// physical (Page, S, D) -> logical (D, TotalSeqK)
|
||||
auto v_dram_final = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_merge_transform(
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size))),
|
||||
make_tuple(sequence<2>{}, sequence<0, 1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_final,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV_, kPadSeqLenK>{});
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
@@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
const index_t stride_k_for_pipeline =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
|
||||
? kVectorSize
|
||||
: kargs.stride_k;
|
||||
const index_t stride_v_for_pipeline =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
|
||||
? kargs.hdim_v
|
||||
: kargs.stride_v;
|
||||
|
||||
auto o_acc_tile = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
@@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
@@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
Reference in New Issue
Block a user