mha_batch_prefill support page_block_size=16 for vllm

This commit is contained in:
yanguahe
2025-07-08 02:46:27 +00:00
parent ee92958ec0
commit 9beec805ba
2 changed files with 604 additions and 113 deletions

View File

@@ -709,7 +709,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
// const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
kargs.seqlen_k = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
#if 0 // we assume page_block_size=1 for now
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
@@ -720,7 +721,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
batch_offset_q = query_start * kargs.stride_q;
if constexpr(kIsSglangLayout)
{
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
@@ -762,11 +762,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
}
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
// #if 0 // we assume page_block_size=1 for now
// kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size +
// last_page_len;
// #else
// kargs.seqlen_k = num_page_blocks;
// #endif
}
else
{
@@ -789,11 +790,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
// #if 0 // we assume page_block_size=1 for now
// kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size +
// last_page_len;
// #else
// kargs.seqlen_k = num_page_blocks;
// #endif
}
// for simplicity, batch stride we just modify the pointer
@@ -861,10 +863,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(
make_pass_through_transform(kargs.hdim_v),
// make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
make_pass_through_transform(kargs.num_total_pages)),
make_tuple(make_pass_through_transform(kargs.hdim_v),
// make_pass_through_transform(kargs.num_total_pages *
// kargs.page_block_size)),
make_pass_through_transform(kargs.num_total_pages)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

View File

@@ -11,8 +11,178 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
// #define USED_VLLM_PAGE_TABLE_VERSION 0
// #define USED_VLLM_PAGE_TABLE_VERSION 1
// #define USED_VLLM_PAGE_TABLE_VERSION 2
#define USED_VLLM_PAGE_TABLE_VERSION 3
namespace ck_tile {
union DoubleIndext
{
index_t idx2[2];
long_index_t ldx;
};
template <typename OffsetVecType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
index_t kPageShiftSize,
index_t kLoopStart,
index_t kLoopCount,
index_t kLoopStride,
bool kIsSglangLayout,
bool kIsKcache>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
const index_t& stride_kv,
const CoordVecType& coord_vec,
OffsetVecType& kv_offset_vec)
{
const index_t& thread_coord_start = coord_vec[kCoordAxis];
// if(blockIdx.x + blockIdx.y + blockIdx.z + threadIdx.y + threadIdx.z == 0 && threadIdx.x == 0)
// if(blockIdx.x + blockIdx.y + threadIdx.y + threadIdx.z == 0 && blockIdx.z == 3 &&
// threadIdx.x == 102)
// {
// printf("kIsSglangLayout=%d\n", kIsSglangLayout);
// if constexpr(kIsKcache)
// {
// printf("k_id: blkz=%d, thr_idx=%d, thr_coord_st=%d\n",
// blockIdx.z,
// threadIdx.x,
// thread_coord_start);
// }
// else
// {
// printf("v_id: blkz=%d, thr_idx=%d, thr_coord_st=%d\n",
// blockIdx.z,
// threadIdx.x,
// thread_coord_start);
// }
// }
if constexpr(kIsSglangLayout)
{
static_for<0, kLoopCount, 1>{}([&](auto k0) {
kv_offset_vec[k0] =
page_vec[thread_coord_start + kLoopStart + kLoopStride * k0.value] * stride_kv;
});
}
else
{
#if USED_VLLM_PAGE_TABLE_VERSION == 3
constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
if constexpr(kIsKcache)
{
// for k_offset_vec
constexpr index_t kPageStride = kLoopStride >> kPageShiftSize;
// constexpr array<index_t, kLoopCount> kPageIdArray = []() {
// array<index_t, kLoopCount> arr;
// static_for<0, kLoopCount, 1>{}([&](auto k0) {
// constexpr index_t kPageId = kPageStride * k0.value;
// arr[k0] = kPageId;
// });
// return arr;
// }();
static_for<0, kLoopCount, 1>{}([&](auto k0) {
// constexpr index_t kPageId = kPageIdArray[k0];
constexpr index_t kPageId = kPageStride * k0.value;
const index_t page_offset =
(thread_coord_start + kLoopStride * k0.value) & kPageMask;
kv_offset_vec[k0] =
((page_vec[kPageId] << kPageShiftSize) + page_offset) * stride_kv;
});
}
else
{
// for v_offset_vec
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id = (lane0_start + kLoopStart) >> kPageShiftSize;
const index_t page_loc = page_vec[lane0_page_id] << kPageShiftSize;
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t page_offset =
(thread_coord_start + kLoopStart + k0.value) & kPageMask;
kv_offset_vec[k0] = (page_loc + page_offset) * stride_kv;
});
}
#else
index_t i_page;
index_t i_seq;
static_for<0, kLoopCount, 1>{}([&](auto k0) {
#if USED_VLLM_PAGE_TABLE_VERSION == 0
int32_t seqlen_v_idx_per_repeat =
thread_coord_start + kLoopStart + kLoopStride * k0.value;
i_page = seqlen_v_idx_per_repeat / kPageBlockSize;
i_seq = seqlen_v_idx_per_repeat % kPageBlockSize;
kv_offset_vec[k0] = (page_vec[i_page] * kPageBlockSize + i_seq) * stride_kv;
#elif USED_VLLM_PAGE_TABLE_VERSION == 1
if constexpr(kIsKcache)
{
// for k_offset_vec
constexpr index_t kItemOffset =
(kLoopStart + kLoopStride * k0.value) >> kPageShiftSize;
i_page = (thread_coord_start >> kPageShiftSize) + kItemOffset;
kv_offset_vec[k0] =
((page_vec[i_page] << kPageShiftSize) + thread_coord_start) * stride_kv;
}
else
{
// for v_offset_vec
constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
int32_t seqlen_v_idx_per_repeat =
thread_coord_start + kLoopStart + kLoopStride * k0.value;
i_page = seqlen_v_idx_per_repeat >> kPageShiftSize;
i_seq = seqlen_v_idx_per_repeat & kPageMask;
kv_offset_vec[k0] = ((page_vec[i_page] << kPageShiftSize) + i_seq) * stride_kv;
}
#elif USED_VLLM_PAGE_TABLE_VERSION == 2
if constexpr(kIsKcache)
{
// for k_offset_vec
// index_t i_page;
constexpr index_t kItemOffset = kLoopStride * k0.value >> kPageShiftSize;
// i_page = (thread_coord_start >> kPageShiftSize) + kItemOffset;
asm volatile("v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
% [thread_coord_start]\n\t " " v_add_u32_e32 % [i_page],
% [kItemOffset],
% [i_page]\n\t " : [i_page] " + v "(i_page) : [thread_coord_start]
"v"(thread_coord_start),
[kItemOffset] "i"(kItemOffset),
[kPageShiftSize] "i"(kPageShiftSize));
kv_offset_vec[k0] =
((page_vec[i_page] << kPageShiftSize) + thread_coord_start) * stride_kv;
}
else
{
constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
constexpr index_t kItemOffset = kLoopStart + kLoopStride * k0.value;
// i_page = thread_coord_start + kItemOffset
// i_seq = i_page & (kPageBlockSize - 1)
// i_page = i_page >> log2(kPageBlockSize)
asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset],
% [thread_coord_start]\n\t
"
"v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t"
"v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
% [i_page]\n\t " : [i_page] " +
v "(i_page), [i_seq] " = v "(i_seq)
: [thread_coord_start] "v"(thread_coord_start),
[kItemOffset] "i"(kItemOffset),
[kPageMask] "i"(kPageMask),
[kPageShiftSize] "i"(kPageShiftSize));
kv_offset_vec[k0] = ((page_vec[i_page] << kPageShiftSize) + i_seq) * stride_kv;
}
#endif
});
#endif
}
}
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_,
typename Policy_ = BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy>
@@ -41,17 +211,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = 16;
// static constexpr index_t kPageBlockSize = 1;
static constexpr index_t kPageShiftSize = 4;
// static constexpr index_t kPageShiftSize = 0;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
@@ -65,8 +241,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
// static constexpr bool kIsSglangLayout = Problem::kIsSglangLayout;
static constexpr bool kIsSglangLayout = Problem::kIsSglangLayout;
// static constexpr bool kIsSglangLayout = true;
static constexpr bool kIsChunkedPrefill = Problem::kIsChunkedPrefill;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
@@ -126,6 +302,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
// return 1;
return 2;
}
else if constexpr(kQKHeaddim <= 192)
@@ -327,21 +504,71 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
if constexpr(kIsSglangLayout)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value;
int32_t i_page = seqlen_k_idx_per_repeat / page_block_size;
int32_t i_seq = seqlen_k_idx_per_repeat % page_block_size;
k_offsets[n0] = (page_idx[i_page] * page_block_size + i_seq) * stride_k;
});
}
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
kPageShiftSize,
0,
NRepeat,
kN0 / NRepeat,
kIsSglangLayout,
true>(page_idx, stride_k, k_coord, k_offsets);
// if constexpr(kIsSglangLayout)
// {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] *
// stride_k;
// });
// }
// else
// {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
// int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value;
// int32_t i_page = seqlen_k_idx_per_repeat /
// kPageBlockSize; int32_t i_seq = seqlen_k_idx_per_repeat
// % kPageBlockSize; k_offsets[n0] = (page_idx[i_page] * kPageBlockSize +
// i_seq) * stride_k;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
// // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// // int32_t seqlen_k_idx_per_repeat = k_coord[0] + n0.value * kN0 /
// NRepeat;
// // int32_t i_page = seqlen_k_idx_per_repeat >>
// kPageShiftSize;
// // int32_t i_seq = seqlen_k_idx_per_repeat & kPageMask;
// // k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
// stride_k;
// constexpr index_t kItemOffset = n0.value * kN0 / NRepeat >>
// kPageShiftSize; int32_t i_page = (k_coord[0] >>
// kPageShiftSize) + kItemOffset; k_offsets[n0] = ((page_idx[i_page] <<
// kPageShiftSize) + k_coord[0]) * stride_k;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// constexpr index_t kItemOffset = n0.value * kN0 / NRepeat;
// index_t i_page;
// index_t i_seq;
// // i_page = k_coord[0] + kItemOffset
// // i_seq = i_page & (kPageBlockSize - 1)
// // i_page = i_page >> log2(kPageBlockSize)
// asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], %[k_coord_0]\n\t"
// "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t"
// "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
// % [i_page]\n\t " : [i_page] " +
// v "(i_page), [i_seq] " = v "(i_seq)
// : [k_coord_0] "v"(k_coord[0]),
// [kItemOffset] "i"(kItemOffset),
// [kPageMask] "i"(kPageMask),
// [kPageShiftSize] "i"(kPageShiftSize));
// k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
// stride_k;
// #endif
// });
// }
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
@@ -375,22 +602,62 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
(void)stride_k;
if constexpr(kIsSglangLayout)
{
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
});
}
else
{
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value;
int32_t i_page = seqlen_v_idx_per_repeat / page_block_size;
int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size;
v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v;
});
}
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kPageShiftSize,
0,
V_KRepeat,
1,
kIsSglangLayout,
false>(page_idx, stride_v, v_coord, v_offsets);
// // (void)stride_k;
// if constexpr(kIsSglangLayout)
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
// });
// }
// else
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
// int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value;
// int32_t i_page = seqlen_v_idx_per_repeat /
// kPageBlockSize; int32_t i_seq = seqlen_v_idx_per_repeat
// % kPageBlockSize; v_offsets[k0] = (page_idx[i_page] * kPageBlockSize +
// i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value;
// int32_t i_page = seqlen_v_idx_per_repeat >>
// kPageShiftSize; int32_t i_seq = seqlen_v_idx_per_repeat
// & kPageMask; v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) +
// i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// constexpr index_t kItemOffset = k0.value;
// index_t i_page;
// index_t i_seq;
// // i_page = v_coord_i + kItemOffset
// // i_seq = i_page & (kPageBlockSize - 1)
// // i_page = i_page >> log2(kPageBlockSize)
// asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], %[v_coord_i]\n\t"
// "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t"
// "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
// %[i_page]\n\t" : [i_page] "+v"(i_page), [i_seq] "=v"(i_seq)
// : [v_coord_i] "v"(v_coord[VPageIndexDim]),
// [kItemOffset] "i"(kItemOffset),
// [kPageMask] "i"(kPageMask),
// [kPageShiftSize] "i"(kPageShiftSize));
// v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
// stride_v;
// #endif
// });
// }
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -454,23 +721,66 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
if constexpr(kIsSglangLayout)
{
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
}
else
{
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
int32_t seqlen_v_idx_per_repeat = kK1 + v_coord[VPageIndexDim] + k0.value;
int32_t i_page = seqlen_v_idx_per_repeat / page_block_size;
int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size;
v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v;
});
}
v_dram_window.update_page_idx(v_offsets);
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kPageShiftSize,
kK1,
V_KRepeat,
1,
kIsSglangLayout,
false>(page_idx, stride_v, v_coord, v_offsets);
// if constexpr(kIsSglangLayout)
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value]
// * stride_v;
// });
// }
// else
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
// int32_t seqlen_v_idx_per_repeat = kK1 + v_coord[VPageIndexDim] +
// k0.value; int32_t i_page =
// seqlen_v_idx_per_repeat / kPageBlockSize; int32_t i_seq =
// seqlen_v_idx_per_repeat % kPageBlockSize; v_offsets[k0] =
// (page_idx[i_page] * kPageBlockSize + i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + kK1 +
// k0.value; int32_t i_page =
// seqlen_v_idx_per_repeat >> kPageShiftSize; int32_t i_seq =
// seqlen_v_idx_per_repeat & kPageMask; v_offsets[k0] =
// ((page_idx[i_page] << kPageShiftSize) + i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// constexpr index_t kItemOffset = kK1 + k0.value;
// index_t i_page;
// index_t i_seq;
// // i_page = v_coord_i + kItemOffset
// // i_seq = i_page & (kPageBlockSize - 1)
// // i_page = i_page >> log2(kPageBlockSize)
// asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset],
// %[v_coord_i]\n\t"
// "v_and_b32_e32 %[i_seq], %[kPageMask],
// %[i_page]\n\t" "v_lshrrev_b32_e32 %[i_page],
// %[kPageShiftSize], %[i_page]\n\t" : [i_page]
// "+v"(i_page), [i_seq] "=v"(i_seq) : [v_coord_i]
// "v"(v_coord[VPageIndexDim]),
// [kItemOffset] "i"(kItemOffset),
// [kPageMask] "i"(kPageMask),
// [kPageShiftSize] "i"(kPageShiftSize));
// v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
// stride_v;
// #endif
// });
// }
v_dram_window.update_page_idx(v_offsets);
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
@@ -625,10 +935,68 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] =
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kPageShiftSize,
2 * kK1,
V_KRepeat,
1,
kIsSglangLayout,
false>(page_idx, stride_v, v_coord, v_offsets);
// if constexpr(kIsSglangLayout)
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// v_offsets[k0] =
// page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value]
// * stride_v;
// });
// }
// else
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
// int32_t seqlen_v_idx_per_repeat =
// kK1 * 2 + v_coord[VPageIndexDim] + k0.value;
// int32_t i_page = seqlen_v_idx_per_repeat /
// kPageBlockSize; int32_t i_seq = seqlen_v_idx_per_repeat
// % kPageBlockSize; v_offsets[k0] = (page_idx[i_page] *
// kPageBlockSize + i_seq) * stride_k;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
// int32_t seqlen_v_idx_per_repeat =
// v_coord[VPageIndexDim] + 2 * kK1 + k0.value;
// int32_t i_page = seqlen_v_idx_per_repeat >>
// kPageShiftSize; int32_t i_seq = seqlen_v_idx_per_repeat
// & kPageMask; v_offsets[k0] = ((page_idx[i_page] <<
// kPageShiftSize) + i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
// constexpr index_t kPageMask = (1 << kPageShiftSize) -
// 1; constexpr index_t kItemOffset = 2 * kK1 + k0.value;
// index_t i_page;
// index_t i_seq;
// // i_page = v_coord_i + kItemOffset
// // i_seq = i_page & (kPageBlockSize - 1)
// // i_page = i_page >> log2(kPageBlockSize)
// asm volatile(
// "v_add_u32_e32 %[i_page], %[kItemOffset],
// %[v_coord_i]\n\t" "v_and_b32_e32 %[i_seq],
// %[kPageMask], %[i_page]\n\t" "v_lshrrev_b32_e32
// %[i_page], %[kPageShiftSize], %[i_page]\n\t" :
// [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) :
// [v_coord_i] "v"(v_coord[VPageIndexDim]),
// [kItemOffset] "i"(kItemOffset),
// [kPageMask] "i"(kPageMask),
// [kPageShiftSize] "i"(kPageShiftSize));
// v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) +
// i_seq) * stride_v;
// #endif
// });
// }
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
@@ -749,24 +1117,85 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
if constexpr(kIsSglangLayout)
{
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
v_coord[VPageIndexDim] + k0.value] *
stride_v;
});
}
else
{
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
int32_t seqlen_v_idx_per_repeat = kK1 * 2 + i_k1.value * kK1 +
v_coord[VPageIndexDim] + k0.value;
int32_t i_page = seqlen_v_idx_per_repeat / page_block_size;
int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size;
v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v;
});
}
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kPageShiftSize,
(2 + i_k1.value) * kK1,
V_KRepeat,
1,
kIsSglangLayout,
false>(page_idx, stride_v, v_coord, v_offsets);
// if constexpr(kIsSglangLayout)
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// v_offsets[k0] = page_idx[kK1 * 2 +
// i_k1.value * kK1 +
// v_coord[VPageIndexDim]
// + k0.value] *
// stride_v;
// });
// }
// else
// {
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
// int32_t seqlen_v_idx_per_repeat =
// kK1 * 2 + i_k1.value * kK1 +
// v_coord[VPageIndexDim] + k0.value;
// int32_t i_page = seqlen_v_idx_per_repeat
// / kPageBlockSize; int32_t i_seq =
// seqlen_v_idx_per_repeat % kPageBlockSize;
// v_offsets[k0] =
// (page_idx[i_page] * kPageBlockSize +
// i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
// constexpr index_t kPageMask = (1 <<
// kPageShiftSize) - 1; int32_t
// seqlen_v_idx_per_repeat =
// v_coord[VPageIndexDim] + 2 * kK1 +
// i_k1.value * kK1 + k0.value;
// int32_t i_page = seqlen_v_idx_per_repeat
// >> kPageShiftSize; int32_t i_seq =
// seqlen_v_idx_per_repeat & kPageMask;
// v_offsets[k0] =
// ((page_idx[i_page] << kPageShiftSize)
// + i_seq) * stride_v;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
// constexpr index_t kPageMask = (1 <<
// kPageShiftSize) - 1; constexpr index_t
// kItemOffset =
// 2 * kK1 + i_k1.value * kK1 +
// k0.value;
// index_t i_page;
// index_t i_seq;
// // i_page = v_coord_i + kItemOffset
// // i_seq = i_page & (kPageBlockSize - 1)
// // i_page = i_page >>
// log2(kPageBlockSize) asm volatile(
// "v_add_u32_e32 %[i_page],
// %[kItemOffset], %[v_coord_i]\n\t"
// "v_and_b32_e32 %[i_seq],
// %[kPageMask], %[i_page]\n\t"
// "v_lshrrev_b32_e32 %[i_page],
// %[kPageShiftSize], %[i_page]\n\t" :
// [i_page] "+v"(i_page), [i_seq]
// "=v"(i_seq) : [v_coord_i]
// "v"(v_coord[VPageIndexDim]),
// [kItemOffset] "i"(kItemOffset),
// [kPageMask] "i"(kPageMask),
// [kPageShiftSize]
// "i"(kPageShiftSize));
// v_offsets[k0] =
// ((page_idx[i_page] << kPageShiftSize)
// + i_seq) * stride_v;
// #endif
// });
// }
v_dram_window.update_page_idx(v_offsets);
}
block_sync_lds();
@@ -807,26 +1236,86 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
page_idx += kN0;
if constexpr(kIsSglangLayout)
{
page_idx += kN0;
}
else
{
page_idx += kN0 / kPageBlockSize;
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(kIsSglangLayout)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value;
int32_t i_page = seqlen_k_idx_per_repeat / page_block_size;
int32_t i_seq = seqlen_k_idx_per_repeat % page_block_size;
k_offsets[n0] = (page_idx[i_page] * page_block_size + i_seq) * stride_k;
});
}
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
kPageShiftSize,
0,
NRepeat,
kN0 / NRepeat,
kIsSglangLayout,
true>(page_idx, stride_k, k_coord, k_offsets);
// if constexpr(kIsSglangLayout)
// {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat *
// n0.value] * stride_k;
// });
// }
// else
// {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
// int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 /
// NRepeat * n0.value; int32_t i_page =
// seqlen_k_idx_per_repeat / kPageBlockSize; int32_t i_seq
// = seqlen_k_idx_per_repeat % kPageBlockSize; k_offsets[n0]
// = (page_idx[i_page] * kPageBlockSize + i_seq) * stride_k;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
// // constexpr index_t kPageMask = (1 <<
// kPageShiftSize) - 1;
// // int32_t seqlen_k_idx_per_repeat = k_coord[0] +
// n0.value * kN0 / NRepeat;
// // int32_t i_page =
// seqlen_k_idx_per_repeat >>
// // kPageShiftSize; int32_t i_seq =
// seqlen_k_idx_per_repeat
// // & kPageMask; k_offsets[n0] = ((page_idx[i_page] <<
// kPageShiftSize) +
// // i_seq) * stride_k;
// constexpr index_t kItemOffset = n0.value * kN0 / NRepeat
// >> kPageShiftSize; int32_t i_page = (k_coord[0] >>
// kPageShiftSize) + kItemOffset; k_offsets[n0] =
// ((page_idx[i_page] << kPageShiftSize) + k_coord[0]) *
// stride_k;
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
// constexpr index_t kPageMask = (1 << kPageShiftSize) -
// 1; constexpr index_t kItemOffset = n0.value * kN0 /
// NRepeat; index_t i_page; index_t i_seq;
// // i_page = k_coord[0] + kItemOffset
// // i_seq = i_page & (kPageBlockSize - 1)
// // i_page = i_page >> log2(kPageBlockSize)
// asm volatile(
// "v_add_u32_e32 %[i_page], %[kItemOffset],
// %[k_coord_0]\n\t" "v_and_b32_e32 %[i_seq],
// %[kPageMask], %[i_page]\n\t" "v_lshrrev_b32_e32
// %[i_page], %[kPageShiftSize], %[i_page]\n\t" :
// [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) :
// [k_coord_0] "v"(k_coord[0]),
// [kItemOffset] "i"(kItemOffset),
// [kPageMask] "i"(kPageMask),
// [kPageShiftSize] "i"(kPageShiftSize));
// k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) +
// i_seq) * stride_k;
// #endif
// });
// }
k_dram_window.update_page_idx(k_offsets);
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))