mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
mha_batch_prefill support page_block_size=16 for vllm
This commit is contained in:
@@ -709,7 +709,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
// const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
kargs.seqlen_k = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
|
||||
#endif
|
||||
@@ -720,7 +721,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
|
||||
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
@@ -762,11 +762,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
}
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
// #if 0 // we assume page_block_size=1 for now
|
||||
// kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size +
|
||||
// last_page_len;
|
||||
// #else
|
||||
// kargs.seqlen_k = num_page_blocks;
|
||||
// #endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -789,11 +790,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
|
||||
#else
|
||||
kargs.seqlen_k = num_page_blocks;
|
||||
#endif
|
||||
// #if 0 // we assume page_block_size=1 for now
|
||||
// kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size +
|
||||
// last_page_len;
|
||||
// #else
|
||||
// kargs.seqlen_k = num_page_blocks;
|
||||
// #endif
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -861,10 +863,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.hdim_v),
|
||||
// make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
|
||||
make_pass_through_transform(kargs.num_total_pages)),
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
// make_pass_through_transform(kargs.num_total_pages *
|
||||
// kargs.page_block_size)),
|
||||
make_pass_through_transform(kargs.num_total_pages)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
|
||||
@@ -11,8 +11,178 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
// Selects which address computation kv_offset_array_transform() uses on its
// non-Sglang (paged) branch; the macro is only tested inside that function.
//   0: page index / in-page offset via '/' and '%' by kPageBlockSize
//   1: same split done with shifts and masks by kPageShiftSize
//   2: shift/mask split written as inline assembly
//   3: K path uses a compile-time page index per repeat; V path does a single
//      page-table read taken from lane 0's coordinate (readfirstlane)
// #define USED_VLLM_PAGE_TABLE_VERSION 0
|
||||
// #define USED_VLLM_PAGE_TABLE_VERSION 1
|
||||
// #define USED_VLLM_PAGE_TABLE_VERSION 2
|
||||
#define USED_VLLM_PAGE_TABLE_VERSION 3
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Overlays a pair of index_t values and a single long_index_t in the same
// storage.
// NOTE(review): presumably meant for reading/writing the two halves of a wide
// index; that only works if long_index_t is exactly twice the width of
// index_t -- confirm against the ck_tile type definitions.
// NOTE(review): not referenced by any code visible in this change.
union DoubleIndext
|
||||
{
|
||||
// Two index_t elements sharing storage with ldx.
index_t idx2[2];
|
||||
// Single long_index_t view of the same storage.
long_index_t ldx;
|
||||
};
|
||||
|
||||
template <typename OffsetVecType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
index_t kPageShiftSize,
|
||||
index_t kLoopStart,
|
||||
index_t kLoopCount,
|
||||
index_t kLoopStride,
|
||||
bool kIsSglangLayout,
|
||||
bool kIsKcache>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
|
||||
const index_t& stride_kv,
|
||||
const CoordVecType& coord_vec,
|
||||
OffsetVecType& kv_offset_vec)
|
||||
{
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
// if(blockIdx.x + blockIdx.y + blockIdx.z + threadIdx.y + threadIdx.z == 0 && threadIdx.x == 0)
|
||||
// if(blockIdx.x + blockIdx.y + threadIdx.y + threadIdx.z == 0 && blockIdx.z == 3 &&
|
||||
// threadIdx.x == 102)
|
||||
// {
|
||||
// printf("kIsSglangLayout=%d\n", kIsSglangLayout);
|
||||
// if constexpr(kIsKcache)
|
||||
// {
|
||||
// printf("k_id: blkz=%d, thr_idx=%d, thr_coord_st=%d\n",
|
||||
// blockIdx.z,
|
||||
// threadIdx.x,
|
||||
// thread_coord_start);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// printf("v_id: blkz=%d, thr_idx=%d, thr_coord_st=%d\n",
|
||||
// blockIdx.z,
|
||||
// threadIdx.x,
|
||||
// thread_coord_start);
|
||||
// }
|
||||
// }
|
||||
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
kv_offset_vec[k0] =
|
||||
page_vec[thread_coord_start + kLoopStart + kLoopStride * k0.value] * stride_kv;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if USED_VLLM_PAGE_TABLE_VERSION == 3
|
||||
constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// for k_offset_vec
|
||||
constexpr index_t kPageStride = kLoopStride >> kPageShiftSize;
|
||||
// constexpr array<index_t, kLoopCount> kPageIdArray = []() {
|
||||
// array<index_t, kLoopCount> arr;
|
||||
// static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
// constexpr index_t kPageId = kPageStride * k0.value;
|
||||
// arr[k0] = kPageId;
|
||||
// });
|
||||
// return arr;
|
||||
// }();
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
// constexpr index_t kPageId = kPageIdArray[k0];
|
||||
constexpr index_t kPageId = kPageStride * k0.value;
|
||||
const index_t page_offset =
|
||||
(thread_coord_start + kLoopStride * k0.value) & kPageMask;
|
||||
kv_offset_vec[k0] =
|
||||
((page_vec[kPageId] << kPageShiftSize) + page_offset) * stride_kv;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// for v_offset_vec
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id = (lane0_start + kLoopStart) >> kPageShiftSize;
|
||||
const index_t page_loc = page_vec[lane0_page_id] << kPageShiftSize;
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t page_offset =
|
||||
(thread_coord_start + kLoopStart + k0.value) & kPageMask;
|
||||
kv_offset_vec[k0] = (page_loc + page_offset) * stride_kv;
|
||||
});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
index_t i_page;
|
||||
index_t i_seq;
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
#if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
int32_t seqlen_v_idx_per_repeat =
|
||||
thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
i_page = seqlen_v_idx_per_repeat / kPageBlockSize;
|
||||
i_seq = seqlen_v_idx_per_repeat % kPageBlockSize;
|
||||
kv_offset_vec[k0] = (page_vec[i_page] * kPageBlockSize + i_seq) * stride_kv;
|
||||
|
||||
#elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// for k_offset_vec
|
||||
constexpr index_t kItemOffset =
|
||||
(kLoopStart + kLoopStride * k0.value) >> kPageShiftSize;
|
||||
i_page = (thread_coord_start >> kPageShiftSize) + kItemOffset;
|
||||
kv_offset_vec[k0] =
|
||||
((page_vec[i_page] << kPageShiftSize) + thread_coord_start) * stride_kv;
|
||||
}
|
||||
else
|
||||
{
|
||||
// for v_offset_vec
|
||||
constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
int32_t seqlen_v_idx_per_repeat =
|
||||
thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
i_page = seqlen_v_idx_per_repeat >> kPageShiftSize;
|
||||
i_seq = seqlen_v_idx_per_repeat & kPageMask;
|
||||
kv_offset_vec[k0] = ((page_vec[i_page] << kPageShiftSize) + i_seq) * stride_kv;
|
||||
}
|
||||
|
||||
#elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// for k_offset_vec
|
||||
// index_t i_page;
|
||||
constexpr index_t kItemOffset = kLoopStride * k0.value >> kPageShiftSize;
|
||||
// i_page = (thread_coord_start >> kPageShiftSize) + kItemOffset;
|
||||
asm volatile("v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
|
||||
% [thread_coord_start]\n\t " " v_add_u32_e32 % [i_page],
|
||||
% [kItemOffset],
|
||||
% [i_page]\n\t " : [i_page] " + v "(i_page) : [thread_coord_start]
|
||||
"v"(thread_coord_start),
|
||||
[kItemOffset] "i"(kItemOffset),
|
||||
[kPageShiftSize] "i"(kPageShiftSize));
|
||||
kv_offset_vec[k0] =
|
||||
((page_vec[i_page] << kPageShiftSize) + thread_coord_start) * stride_kv;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
constexpr index_t kItemOffset = kLoopStart + kLoopStride * k0.value;
|
||||
// i_page = thread_coord_start + kItemOffset
|
||||
// i_seq = i_page & (kPageBlockSize - 1)
|
||||
// i_page = i_page >> log2(kPageBlockSize)
|
||||
asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset],
|
||||
% [thread_coord_start]\n\t
|
||||
"
|
||||
"v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t"
|
||||
"v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
|
||||
% [i_page]\n\t " : [i_page] " +
|
||||
v "(i_page), [i_seq] " = v "(i_seq)
|
||||
: [thread_coord_start] "v"(thread_coord_start),
|
||||
[kItemOffset] "i"(kItemOffset),
|
||||
[kPageMask] "i"(kPageMask),
|
||||
[kPageShiftSize] "i"(kPageShiftSize));
|
||||
kv_offset_vec[k0] = ((page_vec[i_page] << kPageShiftSize) + i_seq) * stride_kv;
|
||||
}
|
||||
#endif
|
||||
});
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy>
|
||||
@@ -41,17 +211,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kPageBlockSize = 16;
|
||||
// static constexpr index_t kPageBlockSize = 1;
|
||||
|
||||
static constexpr index_t kPageShiftSize = 4;
|
||||
// static constexpr index_t kPageShiftSize = 0;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -65,8 +241,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
// static constexpr bool kIsSglangLayout = Problem::kIsSglangLayout;
|
||||
static constexpr bool kIsSglangLayout = Problem::kIsSglangLayout;
|
||||
// static constexpr bool kIsSglangLayout = true;
|
||||
static constexpr bool kIsChunkedPrefill = Problem::kIsChunkedPrefill;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
@@ -126,6 +302,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
// return 1;
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
@@ -327,21 +504,71 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value;
|
||||
int32_t i_page = seqlen_k_idx_per_repeat / page_block_size;
|
||||
int32_t i_seq = seqlen_k_idx_per_repeat % page_block_size;
|
||||
k_offsets[n0] = (page_idx[i_page] * page_block_size + i_seq) * stride_k;
|
||||
});
|
||||
}
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kPageShiftSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kIsSglangLayout,
|
||||
true>(page_idx, stride_k, k_coord, k_offsets);
|
||||
|
||||
// if constexpr(kIsSglangLayout)
|
||||
// {
|
||||
// static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] *
|
||||
// stride_k;
|
||||
// });
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
// int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value;
|
||||
// int32_t i_page = seqlen_k_idx_per_repeat /
|
||||
// kPageBlockSize; int32_t i_seq = seqlen_k_idx_per_repeat
|
||||
// % kPageBlockSize; k_offsets[n0] = (page_idx[i_page] * kPageBlockSize +
|
||||
// i_seq) * stride_k;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
// // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// // int32_t seqlen_k_idx_per_repeat = k_coord[0] + n0.value * kN0 /
|
||||
// NRepeat;
|
||||
// // int32_t i_page = seqlen_k_idx_per_repeat >>
|
||||
// kPageShiftSize;
|
||||
// // int32_t i_seq = seqlen_k_idx_per_repeat & kPageMask;
|
||||
// // k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
|
||||
// stride_k;
|
||||
|
||||
// constexpr index_t kItemOffset = n0.value * kN0 / NRepeat >>
|
||||
// kPageShiftSize; int32_t i_page = (k_coord[0] >>
|
||||
// kPageShiftSize) + kItemOffset; k_offsets[n0] = ((page_idx[i_page] <<
|
||||
// kPageShiftSize) + k_coord[0]) * stride_k;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// constexpr index_t kItemOffset = n0.value * kN0 / NRepeat;
|
||||
// index_t i_page;
|
||||
// index_t i_seq;
|
||||
// // i_page = k_coord[0] + kItemOffset
|
||||
// // i_seq = i_page & (kPageBlockSize - 1)
|
||||
// // i_page = i_page >> log2(kPageBlockSize)
|
||||
// asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], %[k_coord_0]\n\t"
|
||||
// "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t"
|
||||
// "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
|
||||
// % [i_page]\n\t " : [i_page] " +
|
||||
// v "(i_page), [i_seq] " = v "(i_seq)
|
||||
// : [k_coord_0] "v"(k_coord[0]),
|
||||
// [kItemOffset] "i"(kItemOffset),
|
||||
// [kPageMask] "i"(kPageMask),
|
||||
// [kPageShiftSize] "i"(kPageShiftSize));
|
||||
// k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
|
||||
// stride_k;
|
||||
// #endif
|
||||
// });
|
||||
// }
|
||||
|
||||
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
@@ -375,22 +602,62 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
|
||||
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
|
||||
statically_indexed_array<index_t, V_KRepeat> v_offsets;
|
||||
(void)stride_k;
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value;
|
||||
int32_t i_page = seqlen_v_idx_per_repeat / page_block_size;
|
||||
int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size;
|
||||
v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v;
|
||||
});
|
||||
}
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kPageShiftSize,
|
||||
0,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kIsSglangLayout,
|
||||
false>(page_idx, stride_v, v_coord, v_offsets);
|
||||
|
||||
// // (void)stride_k;
|
||||
// if constexpr(kIsSglangLayout)
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
// });
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
// int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value;
|
||||
// int32_t i_page = seqlen_v_idx_per_repeat /
|
||||
// kPageBlockSize; int32_t i_seq = seqlen_v_idx_per_repeat
|
||||
// % kPageBlockSize; v_offsets[k0] = (page_idx[i_page] * kPageBlockSize +
|
||||
// i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value;
|
||||
// int32_t i_page = seqlen_v_idx_per_repeat >>
|
||||
// kPageShiftSize; int32_t i_seq = seqlen_v_idx_per_repeat
|
||||
// & kPageMask; v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) +
|
||||
// i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// constexpr index_t kItemOffset = k0.value;
|
||||
// index_t i_page;
|
||||
// index_t i_seq;
|
||||
// // i_page = v_coord_i + kItemOffset
|
||||
// // i_seq = i_page & (kPageBlockSize - 1)
|
||||
// // i_page = i_page >> log2(kPageBlockSize)
|
||||
// asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], %[v_coord_i]\n\t"
|
||||
// "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t"
|
||||
// "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize],
|
||||
// %[i_page]\n\t" : [i_page] "+v"(i_page), [i_seq] "=v"(i_seq)
|
||||
// : [v_coord_i] "v"(v_coord[VPageIndexDim]),
|
||||
// [kItemOffset] "i"(kItemOffset),
|
||||
// [kPageMask] "i"(kPageMask),
|
||||
// [kPageShiftSize] "i"(kPageShiftSize));
|
||||
// v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
|
||||
// stride_v;
|
||||
// #endif
|
||||
// });
|
||||
// }
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -454,23 +721,66 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
int32_t seqlen_v_idx_per_repeat = kK1 + v_coord[VPageIndexDim] + k0.value;
|
||||
int32_t i_page = seqlen_v_idx_per_repeat / page_block_size;
|
||||
int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size;
|
||||
v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v;
|
||||
});
|
||||
}
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kPageShiftSize,
|
||||
kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kIsSglangLayout,
|
||||
false>(page_idx, stride_v, v_coord, v_offsets);
|
||||
|
||||
// if constexpr(kIsSglangLayout)
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value]
|
||||
// * stride_v;
|
||||
// });
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
// int32_t seqlen_v_idx_per_repeat = kK1 + v_coord[VPageIndexDim] +
|
||||
// k0.value; int32_t i_page =
|
||||
// seqlen_v_idx_per_repeat / kPageBlockSize; int32_t i_seq =
|
||||
// seqlen_v_idx_per_repeat % kPageBlockSize; v_offsets[k0] =
|
||||
// (page_idx[i_page] * kPageBlockSize + i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + kK1 +
|
||||
// k0.value; int32_t i_page =
|
||||
// seqlen_v_idx_per_repeat >> kPageShiftSize; int32_t i_seq =
|
||||
// seqlen_v_idx_per_repeat & kPageMask; v_offsets[k0] =
|
||||
// ((page_idx[i_page] << kPageShiftSize) + i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// constexpr index_t kItemOffset = kK1 + k0.value;
|
||||
// index_t i_page;
|
||||
// index_t i_seq;
|
||||
// // i_page = v_coord_i + kItemOffset
|
||||
// // i_seq = i_page & (kPageBlockSize - 1)
|
||||
// // i_page = i_page >> log2(kPageBlockSize)
|
||||
// asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset],
|
||||
// %[v_coord_i]\n\t"
|
||||
// "v_and_b32_e32 %[i_seq], %[kPageMask],
|
||||
// %[i_page]\n\t" "v_lshrrev_b32_e32 %[i_page],
|
||||
// %[kPageShiftSize], %[i_page]\n\t" : [i_page]
|
||||
// "+v"(i_page), [i_seq] "=v"(i_seq) : [v_coord_i]
|
||||
// "v"(v_coord[VPageIndexDim]),
|
||||
// [kItemOffset] "i"(kItemOffset),
|
||||
// [kPageMask] "i"(kPageMask),
|
||||
// [kPageShiftSize] "i"(kPageShiftSize));
|
||||
// v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) *
|
||||
// stride_v;
|
||||
// #endif
|
||||
// });
|
||||
// }
|
||||
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(
|
||||
@@ -625,10 +935,68 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] =
|
||||
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kPageShiftSize,
|
||||
2 * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kIsSglangLayout,
|
||||
false>(page_idx, stride_v, v_coord, v_offsets);
|
||||
|
||||
// if constexpr(kIsSglangLayout)
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// v_offsets[k0] =
|
||||
// page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value]
|
||||
// * stride_v;
|
||||
// });
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
// int32_t seqlen_v_idx_per_repeat =
|
||||
// kK1 * 2 + v_coord[VPageIndexDim] + k0.value;
|
||||
// int32_t i_page = seqlen_v_idx_per_repeat /
|
||||
// kPageBlockSize; int32_t i_seq = seqlen_v_idx_per_repeat
|
||||
// % kPageBlockSize; v_offsets[k0] = (page_idx[i_page] *
|
||||
// kPageBlockSize + i_seq) * stride_k;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) - 1;
|
||||
// int32_t seqlen_v_idx_per_repeat =
|
||||
// v_coord[VPageIndexDim] + 2 * kK1 + k0.value;
|
||||
// int32_t i_page = seqlen_v_idx_per_repeat >>
|
||||
// kPageShiftSize; int32_t i_seq = seqlen_v_idx_per_repeat
|
||||
// & kPageMask; v_offsets[k0] = ((page_idx[i_page] <<
|
||||
// kPageShiftSize) + i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) -
|
||||
// 1; constexpr index_t kItemOffset = 2 * kK1 + k0.value;
|
||||
// index_t i_page;
|
||||
// index_t i_seq;
|
||||
// // i_page = v_coord_i + kItemOffset
|
||||
// // i_seq = i_page & (kPageBlockSize - 1)
|
||||
// // i_page = i_page >> log2(kPageBlockSize)
|
||||
// asm volatile(
|
||||
// "v_add_u32_e32 %[i_page], %[kItemOffset],
|
||||
// %[v_coord_i]\n\t" "v_and_b32_e32 %[i_seq],
|
||||
// %[kPageMask], %[i_page]\n\t" "v_lshrrev_b32_e32
|
||||
// %[i_page], %[kPageShiftSize], %[i_page]\n\t" :
|
||||
// [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) :
|
||||
// [v_coord_i] "v"(v_coord[VPageIndexDim]),
|
||||
// [kItemOffset] "i"(kItemOffset),
|
||||
// [kPageMask] "i"(kPageMask),
|
||||
// [kPageShiftSize] "i"(kPageShiftSize));
|
||||
// v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) +
|
||||
// i_seq) * stride_v;
|
||||
// #endif
|
||||
// });
|
||||
// }
|
||||
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -749,24 +1117,85 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
|
||||
v_coord[VPageIndexDim] + k0.value] *
|
||||
stride_v;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
int32_t seqlen_v_idx_per_repeat = kK1 * 2 + i_k1.value * kK1 +
|
||||
v_coord[VPageIndexDim] + k0.value;
|
||||
int32_t i_page = seqlen_v_idx_per_repeat / page_block_size;
|
||||
int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size;
|
||||
v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v;
|
||||
});
|
||||
}
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kPageShiftSize,
|
||||
(2 + i_k1.value) * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kIsSglangLayout,
|
||||
false>(page_idx, stride_v, v_coord, v_offsets);
|
||||
|
||||
// if constexpr(kIsSglangLayout)
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// v_offsets[k0] = page_idx[kK1 * 2 +
|
||||
// i_k1.value * kK1 +
|
||||
// v_coord[VPageIndexDim]
|
||||
// + k0.value] *
|
||||
// stride_v;
|
||||
// });
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
// int32_t seqlen_v_idx_per_repeat =
|
||||
// kK1 * 2 + i_k1.value * kK1 +
|
||||
// v_coord[VPageIndexDim] + k0.value;
|
||||
// int32_t i_page = seqlen_v_idx_per_repeat
|
||||
// / kPageBlockSize; int32_t i_seq =
|
||||
// seqlen_v_idx_per_repeat % kPageBlockSize;
|
||||
// v_offsets[k0] =
|
||||
// (page_idx[i_page] * kPageBlockSize +
|
||||
// i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
// constexpr index_t kPageMask = (1 <<
|
||||
// kPageShiftSize) - 1; int32_t
|
||||
// seqlen_v_idx_per_repeat =
|
||||
// v_coord[VPageIndexDim] + 2 * kK1 +
|
||||
// i_k1.value * kK1 + k0.value;
|
||||
// int32_t i_page = seqlen_v_idx_per_repeat
|
||||
// >> kPageShiftSize; int32_t i_seq =
|
||||
// seqlen_v_idx_per_repeat & kPageMask;
|
||||
// v_offsets[k0] =
|
||||
// ((page_idx[i_page] << kPageShiftSize)
|
||||
// + i_seq) * stride_v;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
// constexpr index_t kPageMask = (1 <<
|
||||
// kPageShiftSize) - 1; constexpr index_t
|
||||
// kItemOffset =
|
||||
// 2 * kK1 + i_k1.value * kK1 +
|
||||
// k0.value;
|
||||
// index_t i_page;
|
||||
// index_t i_seq;
|
||||
// // i_page = v_coord_i + kItemOffset
|
||||
// // i_seq = i_page & (kPageBlockSize - 1)
|
||||
// // i_page = i_page >>
|
||||
// log2(kPageBlockSize) asm volatile(
|
||||
// "v_add_u32_e32 %[i_page],
|
||||
// %[kItemOffset], %[v_coord_i]\n\t"
|
||||
// "v_and_b32_e32 %[i_seq],
|
||||
// %[kPageMask], %[i_page]\n\t"
|
||||
// "v_lshrrev_b32_e32 %[i_page],
|
||||
// %[kPageShiftSize], %[i_page]\n\t" :
|
||||
// [i_page] "+v"(i_page), [i_seq]
|
||||
// "=v"(i_seq) : [v_coord_i]
|
||||
// "v"(v_coord[VPageIndexDim]),
|
||||
// [kItemOffset] "i"(kItemOffset),
|
||||
// [kPageMask] "i"(kPageMask),
|
||||
// [kPageShiftSize]
|
||||
// "i"(kPageShiftSize));
|
||||
// v_offsets[k0] =
|
||||
// ((page_idx[i_page] << kPageShiftSize)
|
||||
// + i_seq) * stride_v;
|
||||
// #endif
|
||||
// });
|
||||
// }
|
||||
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -807,26 +1236,86 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
page_idx += kN0;
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
page_idx += kN0;
|
||||
}
|
||||
else
|
||||
{
|
||||
page_idx += kN0 / kPageBlockSize;
|
||||
}
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(kIsSglangLayout)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value;
|
||||
int32_t i_page = seqlen_k_idx_per_repeat / page_block_size;
|
||||
int32_t i_seq = seqlen_k_idx_per_repeat % page_block_size;
|
||||
k_offsets[n0] = (page_idx[i_page] * page_block_size + i_seq) * stride_k;
|
||||
});
|
||||
}
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kPageShiftSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kIsSglangLayout,
|
||||
true>(page_idx, stride_k, k_coord, k_offsets);
|
||||
|
||||
// if constexpr(kIsSglangLayout)
|
||||
// {
|
||||
// static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat *
|
||||
// n0.value] * stride_k;
|
||||
// });
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// #if USED_VLLM_PAGE_TABLE_VERSION == 0
|
||||
// int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 /
|
||||
// NRepeat * n0.value; int32_t i_page =
|
||||
// seqlen_k_idx_per_repeat / kPageBlockSize; int32_t i_seq
|
||||
// = seqlen_k_idx_per_repeat % kPageBlockSize; k_offsets[n0]
|
||||
// = (page_idx[i_page] * kPageBlockSize + i_seq) * stride_k;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 1
|
||||
// // constexpr index_t kPageMask = (1 <<
|
||||
// kPageShiftSize) - 1;
|
||||
// // int32_t seqlen_k_idx_per_repeat = k_coord[0] +
|
||||
// n0.value * kN0 / NRepeat;
|
||||
// // int32_t i_page =
|
||||
// seqlen_k_idx_per_repeat >>
|
||||
// // kPageShiftSize; int32_t i_seq =
|
||||
// seqlen_k_idx_per_repeat
|
||||
// // & kPageMask; k_offsets[n0] = ((page_idx[i_page] <<
|
||||
// kPageShiftSize) +
|
||||
// // i_seq) * stride_k;
|
||||
|
||||
// constexpr index_t kItemOffset = n0.value * kN0 / NRepeat
|
||||
// >> kPageShiftSize; int32_t i_page = (k_coord[0] >>
|
||||
// kPageShiftSize) + kItemOffset; k_offsets[n0] =
|
||||
// ((page_idx[i_page] << kPageShiftSize) + k_coord[0]) *
|
||||
// stride_k;
|
||||
// #elif USED_VLLM_PAGE_TABLE_VERSION == 2
|
||||
// constexpr index_t kPageMask = (1 << kPageShiftSize) -
|
||||
// 1; constexpr index_t kItemOffset = n0.value * kN0 /
|
||||
// NRepeat; index_t i_page; index_t i_seq;
|
||||
// // i_page = k_coord[0] + kItemOffset
|
||||
// // i_seq = i_page & (kPageBlockSize - 1)
|
||||
// // i_page = i_page >> log2(kPageBlockSize)
|
||||
// asm volatile(
|
||||
// "v_add_u32_e32 %[i_page], %[kItemOffset],
|
||||
// %[k_coord_0]\n\t" "v_and_b32_e32 %[i_seq],
|
||||
// %[kPageMask], %[i_page]\n\t" "v_lshrrev_b32_e32
|
||||
// %[i_page], %[kPageShiftSize], %[i_page]\n\t" :
|
||||
// [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) :
|
||||
// [k_coord_0] "v"(k_coord[0]),
|
||||
// [kItemOffset] "i"(kItemOffset),
|
||||
// [kPageMask] "i"(kPageMask),
|
||||
// [kPageShiftSize] "i"(kPageShiftSize));
|
||||
// k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) +
|
||||
// i_seq) * stride_k;
|
||||
// #endif
|
||||
// });
|
||||
// }
|
||||
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
|
||||
Reference in New Issue
Block a user