mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Merge commit 'cc75a1dc5f18613af29d8821375f79b0f3c6410b' into develop
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// KV cache memory layout selector.
|
||||
//
|
||||
// Layout summary (kVectorSize = 16 / sizeof(KDataType)):
|
||||
// - VECTORIZED_LAYOUT (swizzled):
|
||||
// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize]
|
||||
// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// - LINEAR_LAYOUT:
|
||||
// K: [NumBlocks, PageSize, NumHeads, HeadDim]
|
||||
// V: [NumBlocks, PageSize, NumHeads, HeadDim]
|
||||
enum class BlockAttentionKVCacheMemoryLayoutEnum
|
||||
{
|
||||
VECTORIZED_LAYOUT = 0,
|
||||
LINEAR_LAYOUT = 1,
|
||||
};
|
||||
|
||||
// KV cache lookup table layout selector.
|
||||
// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq]
|
||||
// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1])
|
||||
enum class BlockAttentionKVCacheLookupTableEnum
|
||||
{
|
||||
VLLM_BLOCK_TABLE_2D = 0,
|
||||
SGLANG_PAGE_TABLE_1D = 1,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
@@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
|
||||
static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
|
||||
static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
|
||||
static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct FmhaFwdEmptyKargs
|
||||
@@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct SglangPageTableKargs
|
||||
{
|
||||
const int32_t* kv_indptr;
|
||||
const int32_t* kv_page_indices;
|
||||
const int32_t* kv_last_page_lens;
|
||||
};
|
||||
|
||||
struct VllmPageTableKargs
|
||||
{
|
||||
const int32_t* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
};
|
||||
|
||||
using PageBlockTableKargs =
|
||||
std::conditional_t<kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
SglangPageTableKargs,
|
||||
VllmPageTableKargs>;
|
||||
|
||||
struct FmhaFwdCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
@@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
|
||||
int32_t num_total_pages;
|
||||
const int32_t* kv_indptr;
|
||||
const int32_t* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#else
|
||||
static constexpr ck_tile::index_t page_block_size = 1;
|
||||
#endif
|
||||
PageBlockTableKargs page_table;
|
||||
|
||||
float scale_s;
|
||||
|
||||
@@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
const PageBlockTableKargs& page_table,
|
||||
float scale_s,
|
||||
[[maybe_unused]] float scale_p,
|
||||
[[maybe_unused]] float scale_o,
|
||||
@@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_total_pages,
|
||||
reinterpret_cast<const int32_t*>(kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(kv_page_indices),
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
reinterpret_cast<const int32_t*>(kv_last_page_lens),
|
||||
page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
const PageBlockTableKargs& page_table,
|
||||
float scale_s,
|
||||
[[maybe_unused]] float scale_p,
|
||||
[[maybe_unused]] float scale_o,
|
||||
@@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_total_pages,
|
||||
reinterpret_cast<const int32_t*>(kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(kv_page_indices),
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
reinterpret_cast<const int32_t*>(kv_last_page_lens),
|
||||
page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
|
||||
#endif
|
||||
const index_t seqlen_k = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
|
||||
const int32_t num_page_blocks = page_end - page_start;
|
||||
const int32_t last_page_len = [&]() {
|
||||
if constexpr(kPageBlockSize == 1)
|
||||
return static_cast<int32_t>(kPageBlockSize);
|
||||
else
|
||||
return kargs.page_table.kv_last_page_lens[i_batch];
|
||||
}();
|
||||
return num_page_blocks > 0
|
||||
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
|
||||
last_page_len)
|
||||
: 0;
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
if(kargs.page_table.seqlen_k_ptr != nullptr)
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
|
||||
else
|
||||
return kargs.seqlen_k;
|
||||
}
|
||||
}();
|
||||
const int32_t* page_idx = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
return kargs.page_table.block_table_ptr +
|
||||
static_cast<long_index_t>(i_batch) *
|
||||
kargs.page_table.batch_stride_block_table;
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
@@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
@@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
return;
|
||||
}
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
kargs.seqlen_k = seqlen_k;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
@@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
kargs.seqlen_k = seqlen_k;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
// Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
|
||||
// Logical View for Pipeline: (TotalSeqK, D)
|
||||
|
||||
// Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
|
||||
// PageBlockSize, kVectorSize)
|
||||
// Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages,
|
||||
kargs.hdim_q / kVectorSize,
|
||||
kargs.page_block_size,
|
||||
kVectorSize),
|
||||
make_tuple(
|
||||
kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
// Merge to (TotalSeqK, D) in a single transform:
|
||||
// physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
|
||||
auto k_dram_2d = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size)), // TotalSeqK
|
||||
make_merge_transform(
|
||||
make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
|
||||
static_cast<int32_t>(kVectorSize)))), // D
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
k_dram_2d,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
|
||||
// Logical View for Pipeline: (TotalSeqK, D)
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
// Merge to (TotalSeqK, D) in a single transform:
|
||||
// physical (Page, S, D) -> logical (TotalSeqK, D)
|
||||
auto k_dram_2d = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size)),
|
||||
make_pass_through_transform(kargs.hdim_q)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_2d,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
|
||||
// Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
|
||||
|
||||
// Define the naive physical view with 4D shape: (NumPages,
|
||||
// PageBlockSize/kVectorSize, HeadDim, kVectorSize)
|
||||
// Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size / kVectorSize,
|
||||
kargs.hdim_v,
|
||||
kVectorSize),
|
||||
make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
// Merge to (D, TotalSeqK) in a single transform:
|
||||
// physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
|
||||
auto v_dram_final = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v), // D
|
||||
make_merge_transform(make_tuple(kargs.num_total_pages,
|
||||
kargs.page_block_size / kVectorSize,
|
||||
kVectorSize))), // TotalSeqK
|
||||
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_final,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
|
||||
// Logical View for Pipeline: (D, TotalSeqK)
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
|
||||
return pad_tensor_view(
|
||||
// Merge to (D, TotalSeqK) in a single transform:
|
||||
// physical (Page, S, D) -> logical (D, TotalSeqK)
|
||||
auto v_dram_final = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_merge_transform(
|
||||
make_tuple(kargs.num_total_pages, kargs.page_block_size))),
|
||||
make_tuple(sequence<2>{}, sequence<0, 1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_final,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV_, kPadSeqLenK>{});
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
@@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
const index_t stride_k_for_pipeline =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
|
||||
? kVectorSize
|
||||
: kargs.stride_k;
|
||||
const index_t stride_v_for_pipeline =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
|
||||
? kargs.hdim_v
|
||||
: kargs.stride_v;
|
||||
|
||||
auto o_acc_tile = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
@@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
@@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
kargs.kv_page_indices,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -6,12 +6,82 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename OffsetVecType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
index_t kLog2PageSize,
|
||||
index_t kLoopStart,
|
||||
index_t kLoopCount,
|
||||
index_t kLoopStride,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
|
||||
const index_t& stride_kv,
|
||||
const index_t& page_stride_kv,
|
||||
const CoordVecType& coord_vec,
|
||||
OffsetVecType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
{
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// for k offsets
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t page_offset = global_token_idx & kInPageOffsetMask;
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
|
||||
static_cast<long_index_t>(page_offset) * stride_kv;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// for v offsets
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
|
||||
const long_index_t page_loc =
|
||||
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
|
||||
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t page_offset =
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
|
||||
kInPageOffsetMask;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout offset
|
||||
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
|
||||
const index_t s = page_offset;
|
||||
const index_t D = stride_kv;
|
||||
|
||||
const long_index_t s_offset =
|
||||
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
|
||||
(s % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_loc + s_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
|
||||
template <typename Problem_,
|
||||
@@ -41,19 +111,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
|
||||
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
|
||||
static constexpr index_t kVectorSize = Problem::kVectorSize;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kPageBlockSize % kN0 == 0,
|
||||
"V offset assumes each tile stays within a page; kPageBlockSize must be "
|
||||
"divisible by kN0.");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
|
||||
@@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
index_t current_seq_k = seqlen_k_start;
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
|
||||
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
@@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
|
||||
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
|
||||
statically_indexed_array<index_t, V_KRepeat> v_offsets;
|
||||
(void)stride_k;
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
async_load_fence();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(
|
||||
@@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto p = [&]() {
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
[&variant, &variant_params, &block_indices](auto& x) {
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
|
||||
auto& x) {
|
||||
x = variant.LogitsTransform(variant_params,
|
||||
variant.QueryTransform(variant_params, x),
|
||||
block_indices.batch_idx,
|
||||
@@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
|
||||
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
|
||||
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
|
||||
// Avoid data hazard if v_mfma is followed by inline asm consumer
|
||||
// instructions. In this case, compiler won't add s_nop for us
|
||||
if(i == s_acc.thread_buf_.size() / 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Avoid data hazard if v_mfma is followed by inline asm consumer
|
||||
// instructions. In this case, compiler won't add s_nop for us
|
||||
if(i == s_acc.thread_buf_.size() / 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
#endif
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(
|
||||
v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] =
|
||||
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration. alibi does not have this problem
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); },
|
||||
m,
|
||||
m_old,
|
||||
m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(
|
||||
v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0,
|
||||
kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
2 * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration. alibi does not have this problem
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
}();
|
||||
#else
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
dropout
|
||||
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
|
||||
@@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
|
||||
v_coord[VPageIndexDim] + k0.value] *
|
||||
stride_v;
|
||||
});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
(2 + i_k1.value) * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
page_idx += kN0;
|
||||
current_seq_k += kN0;
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
@@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
@@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
bool kUseTrLoad_,
|
||||
int kPageBlockSize_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBatchPrefillPipelineProblem
|
||||
: public BlockFmhaPipelineProblem<QDataType_,
|
||||
KDataType_,
|
||||
VDataType_,
|
||||
SaccDataType_,
|
||||
SMPLComputeDataType_,
|
||||
BiasDataType_,
|
||||
RandValOutputDataType_,
|
||||
LSEDataType_,
|
||||
PDataType_,
|
||||
OaccDataType_,
|
||||
ODataType_,
|
||||
BlockFmhaShape_,
|
||||
kIsGroupMode_,
|
||||
AttentionVariant_,
|
||||
FmhaMask_,
|
||||
kUseTrLoad_,
|
||||
Traits_>
|
||||
{
|
||||
static constexpr index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
|
||||
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
|
||||
"kPageBlockSize must be power of two");
|
||||
static constexpr index_t kLog2PageSize = []() constexpr {
|
||||
index_t shift = 0;
|
||||
index_t val = kPageBlockSize_;
|
||||
while(val > 1)
|
||||
{
|
||||
val >>= 1;
|
||||
shift++;
|
||||
}
|
||||
return shift;
|
||||
}();
|
||||
|
||||
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
|
||||
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
|
||||
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
|
||||
static constexpr bool kIsVectorizedLayout =
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
|
||||
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
|
||||
"kQKHeaddim must be divisible by kVectorSize");
|
||||
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
|
||||
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
|
||||
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
@@ -40,6 +41,48 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* padding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* padding for hdim_v */,
|
||||
bool kHasLogitsSoftCap_,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
index_t kPageBlockSize_ = 1,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
|
||||
kPadSeqLenK_,
|
||||
kPadHeadDimQ_,
|
||||
kPadHeadDimV_,
|
||||
kHasLogitsSoftCap_,
|
||||
BiasEnum_,
|
||||
kHasBiasGrad_,
|
||||
kStoreLSE_,
|
||||
kHasDropout_,
|
||||
QScaleEnum_,
|
||||
kBlockPerCu_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
|
||||
"Batch prefill only supports vectorized or linear KV cache layout.");
|
||||
static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0),
|
||||
"kPageBlockSize should be a power of 2 to support efficient page-based KV cache "
|
||||
"addressing.");
|
||||
};
|
||||
|
||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
index_t kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
|
||||
Reference in New Issue
Block a user