mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[FMHA] Batch Prefill Support Improvements: Change KV Cache Layout & Large Page Size Support (#3442)
* add page_block_size parameter
* add is_sglang_layout to parameters
* add kv_offset_array_transform to batch async for page size 16
* add kv_last_page_lens to kernel
* change kv layout to [num_total_pages, page_block_size, hdim]
* format
* - enable codegen of batch_prefill kernels
- create new problem struct BlockFmhaBatchPrefillPipelineProblem for
batch prefill kernels
- generate different page sizes of batch prefill kernels (1, 16)
* 1. fix wrong calculation of page id in kv_offset_array_transform in gfx950
2. support page size 1024
* fix python format
* change kv cache layout to [num_blocks, num_kv_heads, head_size/x,
block_size, x] and [num_blocks, num_kv_heads, block_size/X, head_size, X]
* 1. Introduced `kVectorSize` in BlockFmhaBatchPrefillPipelineProblem instead of using hardcode values
2. Makes batch prefill kernel traits structures inherent from fmha fwd
traits
3. Add some static check for Page size, vector size, hdim, ..., etc.
* [Refactor] Replace is_sglang_layout with Enums for KV cache configuration
Refactored `fmha_batch_prefill` to use `BlockAttentionKVCacheMemoryLayoutEnum` (VECTORIZED/LINEAR) and `BlockAttentionKVCacheLookupTableEnum` (SGLANG_1D/VLLM_2D) instead of a single
boolean.
**Changes:**
* Added Enum definitions in `block_attention_kvcache_layout_enum.hpp`.
* Updated Kernel, Pipeline, and Traits to template on these Enums.
* Implemented `kv_offset_array_transform` logic based on `kKVMemoryLayout`.
* Refactored `PageBlockTableKargs` to adapt to `kKVLookupTable`.
* Updated CodeGen scripts to support new parameters.
This decouples memory layout from the paging mechanism, enabling flexible KV cache configurations.
* 1. remove batch prefill pipeline with sk_pad=false
2. correct some comments
3. add static assert to make sure v offsets is in same page within a tile.
* fix vgpr spill count
* remove unnecessary t2s functions
* add fp8 support for receipt 200 and 600 in fmha_bath_prefill.py
* support linear kv cache layout
* Remove block_table_ptr from fwd_batch_prefill_args. Instead, reuse
kv_page_indices as a pointer of the lookup table.
* 1. merge multiple transforms into single transform.
2. add static check to make sure vlayout is row-major.
* move FmhaFwdCommonKargs::seqlen_k_ptr to VllmPageTableKargs.
* update changelog
---------
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: cc75a1dc5f]
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// KV cache memory layout selector.
|
||||
//
|
||||
// Layout summary (kVectorSize = 16 / sizeof(KDataType)):
|
||||
// - VECTORIZED_LAYOUT (swizzled):
|
||||
// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize]
|
||||
// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// - LINEAR_LAYOUT:
|
||||
// K: [NumBlocks, PageSize, NumHeads, HeadDim]
|
||||
// V: [NumBlocks, PageSize, NumHeads, HeadDim]
|
||||
enum class BlockAttentionKVCacheMemoryLayoutEnum
|
||||
{
|
||||
VECTORIZED_LAYOUT = 0,
|
||||
LINEAR_LAYOUT = 1,
|
||||
};
|
||||
|
||||
// KV cache lookup table layout selector.
|
||||
// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq]
|
||||
// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1])
|
||||
enum class BlockAttentionKVCacheLookupTableEnum
|
||||
{
|
||||
VLLM_BLOCK_TABLE_2D = 0,
|
||||
SGLANG_PAGE_TABLE_1D = 1,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
@@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
|
||||
static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
|
||||
static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
|
||||
static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct FmhaFwdEmptyKargs
|
||||
@@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct SglangPageTableKargs
|
||||
{
|
||||
const int32_t* kv_indptr;
|
||||
const int32_t* kv_page_indices;
|
||||
const int32_t* kv_last_page_lens;
|
||||
};
|
||||
|
||||
struct VllmPageTableKargs
|
||||
{
|
||||
const int32_t* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
};
|
||||
|
||||
using PageBlockTableKargs =
|
||||
std::conditional_t<kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
SglangPageTableKargs,
|
||||
VllmPageTableKargs>;
|
||||
|
||||
struct FmhaFwdCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
@@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
|
||||
int32_t num_total_pages;
|
||||
const int32_t* kv_indptr;
|
||||
const int32_t* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#else
|
||||
static constexpr ck_tile::index_t page_block_size = 1;
|
||||
#endif
|
||||
PageBlockTableKargs page_table;
|
||||
|
||||
float scale_s;
|
||||
|
||||
@@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
const PageBlockTableKargs& page_table,
|
||||
float scale_s,
|
||||
[[maybe_unused]] float scale_p,
|
||||
[[maybe_unused]] float scale_o,
|
||||
@@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_total_pages,
|
||||
reinterpret_cast<const int32_t*>(kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(kv_page_indices),
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
reinterpret_cast<const int32_t*>(kv_last_page_lens),
|
||||
page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
const PageBlockTableKargs& page_table,
|
||||
float scale_s,
|
||||
[[maybe_unused]] float scale_p,
|
||||
[[maybe_unused]] float scale_o,
|
||||
@@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_total_pages,
|
||||
reinterpret_cast<const int32_t*>(kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(kv_page_indices),
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
reinterpret_cast<const int32_t*>(kv_last_page_lens),
|
||||
page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
|
||||
#endif
|
||||
const index_t seqlen_k = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
|
||||
const int32_t num_page_blocks = page_end - page_start;
|
||||
const int32_t last_page_len = [&]() {
|
||||
if constexpr(kPageBlockSize == 1)
|
||||
return static_cast<int32_t>(kPageBlockSize);
|
||||
else
|
||||
return kargs.page_table.kv_last_page_lens[i_batch];
|
||||
}();
|
||||
return num_page_blocks > 0
|
||||
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
|
||||
last_page_len)
|
||||
: 0;
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
if(kargs.page_table.seqlen_k_ptr != nullptr)
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
|
||||
else
|
||||
return kargs.seqlen_k;
|
||||
}
|
||||
}();
|
||||
const int32_t* page_idx = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
return kargs.page_table.block_table_ptr +
|
||||
static_cast<long_index_t>(i_batch) *
|
||||
kargs.page_table.batch_stride_block_table;
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
@@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
@@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
return;
|
||||
}
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
kargs.seqlen_k = seqlen_k;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
@@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
kargs.seqlen_k = seqlen_k;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
// Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
|
||||
// Logical View for Pipeline: (TotalSeqK, D)
|
||||
|
||||
// Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
|
||||
// PageBlockSize, kVectorSize)
|
||||
// Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages,
|
||||
kargs.hdim_q / kVectorSize,
|
||||
kargs.page_block_size,
|
||||
kVectorSize),
|
||||
make_tuple(
|
||||
kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
// Merge to (TotalSeqK, D) in a single transform:
|
||||
// physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
|
||||
auto k_dram_2d = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size)), // TotalSeqK
|
||||
make_merge_transform(
|
||||
make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
|
||||
static_cast<int32_t>(kVectorSize)))), // D
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
k_dram_2d,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
|
||||
// Logical View for Pipeline: (TotalSeqK, D)
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
// Merge to (TotalSeqK, D) in a single transform:
|
||||
// physical (Page, S, D) -> logical (TotalSeqK, D)
|
||||
auto k_dram_2d = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size)),
|
||||
make_pass_through_transform(kargs.hdim_q)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_2d,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
|
||||
// Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
|
||||
|
||||
// Define the naive physical view with 4D shape: (NumPages,
|
||||
// PageBlockSize/kVectorSize, HeadDim, kVectorSize)
|
||||
// Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size / kVectorSize,
|
||||
kargs.hdim_v,
|
||||
kVectorSize),
|
||||
make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
// Merge to (D, TotalSeqK) in a single transform:
|
||||
// physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
|
||||
auto v_dram_final = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v), // D
|
||||
make_merge_transform(make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size / kVectorSize,
|
||||
kVectorSize))), // TotalSeqK
|
||||
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_final,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
|
||||
// Logical View for Pipeline: (D, TotalSeqK)
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
|
||||
return pad_tensor_view(
|
||||
// Merge to (D, TotalSeqK) in a single transform:
|
||||
// physical (Page, S, D) -> logical (D, TotalSeqK)
|
||||
auto v_dram_final = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_merge_transform(
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size))),
|
||||
make_tuple(sequence<2>{}, sequence<0, 1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_final,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV_, kPadSeqLenK>{});
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
@@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
const index_t stride_k_for_pipeline =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
|
||||
? kVectorSize
|
||||
: kargs.stride_k;
|
||||
const index_t stride_v_for_pipeline =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
|
||||
? kargs.hdim_v
|
||||
: kargs.stride_v;
|
||||
|
||||
auto o_acc_tile = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
@@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
@@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -6,12 +6,82 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename OffsetVecType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
index_t kLog2PageSize,
|
||||
index_t kLoopStart,
|
||||
index_t kLoopCount,
|
||||
index_t kLoopStride,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
|
||||
const index_t& stride_kv,
|
||||
const index_t& page_stride_kv,
|
||||
const CoordVecType& coord_vec,
|
||||
OffsetVecType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
{
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// for k offsets
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t page_offset = global_token_idx & kInPageOffsetMask;
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
|
||||
static_cast<long_index_t>(page_offset) * stride_kv;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// for v offsets
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
|
||||
const long_index_t page_loc =
|
||||
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
|
||||
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t page_offset =
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
|
||||
kInPageOffsetMask;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout offset
|
||||
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
|
||||
const index_t s = page_offset;
|
||||
const index_t D = stride_kv;
|
||||
|
||||
const long_index_t s_offset =
|
||||
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
|
||||
(s % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_loc + s_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
|
||||
template <typename Problem_,
|
||||
@@ -41,19 +111,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
|
||||
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
|
||||
static constexpr index_t kVectorSize = Problem::kVectorSize;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kPageBlockSize % kN0 == 0,
|
||||
"V offset assumes each tile stays within a page; kPageBlockSize must be "
|
||||
"divisible by kN0.");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
|
||||
@@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
index_t current_seq_k = seqlen_k_start;
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
|
||||
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
@@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
|
||||
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
|
||||
statically_indexed_array<index_t, V_KRepeat> v_offsets;
|
||||
(void)stride_k;
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
async_load_fence();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(
|
||||
@@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto p = [&]() {
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
[&variant, &variant_params, &block_indices](auto& x) {
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
|
||||
auto& x) {
|
||||
x = variant.LogitsTransform(variant_params,
|
||||
variant.QueryTransform(variant_params, x),
|
||||
block_indices.batch_idx,
|
||||
@@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
|
||||
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
|
||||
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
|
||||
// Avoid data hazard if v_mfma is followed by inline asm consumer
|
||||
// instructions. In this case, compiler won't add s_nop for us
|
||||
if(i == s_acc.thread_buf_.size() / 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Avoid data hazard if v_mfma is followed by inline asm consumer
|
||||
// instructions. In this case, compiler won't add s_nop for us
|
||||
if(i == s_acc.thread_buf_.size() / 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
#endif
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(
|
||||
v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] =
|
||||
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration. alibi does not have this problem
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); },
|
||||
m,
|
||||
m_old,
|
||||
m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(
|
||||
v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0,
|
||||
kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
2 * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration. alibi does not have this problem
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
}();
|
||||
#else
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
dropout
|
||||
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
|
||||
@@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
|
||||
v_coord[VPageIndexDim] + k0.value] *
|
||||
stride_v;
|
||||
});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
(2 + i_k1.value) * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
page_idx += kN0;
|
||||
current_seq_k += kN0;
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
@@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
@@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
bool kUseTrLoad_,
|
||||
int kPageBlockSize_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBatchPrefillPipelineProblem
|
||||
: public BlockFmhaPipelineProblem<QDataType_,
|
||||
KDataType_,
|
||||
VDataType_,
|
||||
SaccDataType_,
|
||||
SMPLComputeDataType_,
|
||||
BiasDataType_,
|
||||
RandValOutputDataType_,
|
||||
LSEDataType_,
|
||||
PDataType_,
|
||||
OaccDataType_,
|
||||
ODataType_,
|
||||
BlockFmhaShape_,
|
||||
kIsGroupMode_,
|
||||
AttentionVariant_,
|
||||
FmhaMask_,
|
||||
kUseTrLoad_,
|
||||
Traits_>
|
||||
{
|
||||
static constexpr index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
|
||||
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
|
||||
"kPageBlockSize must be power of two");
|
||||
static constexpr index_t kLog2PageSize = []() constexpr {
|
||||
index_t shift = 0;
|
||||
index_t val = kPageBlockSize_;
|
||||
while(val > 1)
|
||||
{
|
||||
val >>= 1;
|
||||
shift++;
|
||||
}
|
||||
return shift;
|
||||
}();
|
||||
|
||||
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
|
||||
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
|
||||
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
|
||||
static constexpr bool kIsVectorizedLayout =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
|
||||
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
|
||||
"kQKHeaddim must be divisible by kVectorSize");
|
||||
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
|
||||
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
|
||||
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
@@ -40,6 +41,48 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* padding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* padding for hdim_v */,
|
||||
bool kHasLogitsSoftCap_,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
index_t kPageBlockSize_ = 1,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
|
||||
kPadSeqLenK_,
|
||||
kPadHeadDimQ_,
|
||||
kPadHeadDimV_,
|
||||
kHasLogitsSoftCap_,
|
||||
BiasEnum_,
|
||||
kHasBiasGrad_,
|
||||
kStoreLSE_,
|
||||
kHasDropout_,
|
||||
QScaleEnum_,
|
||||
kBlockPerCu_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
|
||||
"Batch prefill only supports vectorized or linear KV cache layout.");
|
||||
static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0),
|
||||
"kPageBlockSize should be a power of 2 to support efficient page-based KV cache "
|
||||
"addressing.");
|
||||
};
|
||||
|
||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
index_t kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
|
||||
Reference in New Issue
Block a user