From 9beec805ba93ae10d0b615eca6d311dabfe2b093 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Tue, 8 Jul 2025 02:46:27 +0000 Subject: [PATCH] mha_batch_prefill support page_block_size=16 for vllm --- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 34 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 683 +++++++++++++++--- 2 files changed, 604 insertions(+), 113 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 5ba55b5229..b3a56207e3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -709,7 +709,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; - const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; + // const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; + kargs.seqlen_k = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; #if 0 // we assume page_block_size=1 for now const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; #endif @@ -720,7 +721,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel batch_offset_q = query_start * kargs.stride_q; - if constexpr(kIsSglangLayout) { kargs.kv_page_indices += kargs.kv_indptr[i_batch]; @@ -762,11 +762,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } } -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + // #if 0 // we assume page_block_size=1 for now + // kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + + // last_page_len; + // #else + // kargs.seqlen_k = num_page_blocks; + // #endif } else { @@ -789,11 +790,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + // #if 0 // we assume page_block_size=1 for now + // kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + + // last_page_len; + // #else + // kargs.seqlen_k = num_page_blocks; + // #endif } // for simplicity, batch stride we just modify the pointer @@ -861,10 +863,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const auto v_dram_transposed = transform_tensor_view( v_dram_naive, - make_tuple( - make_pass_through_transform(kargs.hdim_v), - // make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), - make_pass_through_transform(kargs.num_total_pages)), + make_tuple(make_pass_through_transform(kargs.hdim_v), + // make_pass_through_transform(kargs.num_total_pages * + // kargs.page_block_size)), + make_pass_through_transform(kargs.num_total_pages)), make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 0bf6c96bdf..91b062cf72 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -11,8 +11,178 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" +// #define USED_VLLM_PAGE_TABLE_VERSION 0 +// #define USED_VLLM_PAGE_TABLE_VERSION 1 +// #define USED_VLLM_PAGE_TABLE_VERSION 2 +#define USED_VLLM_PAGE_TABLE_VERSION 3 + namespace ck_tile { +union DoubleIndext +{ + index_t idx2[2]; + long_index_t ldx; +}; + +template +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec, + const index_t& stride_kv, + const CoordVecType& coord_vec, + OffsetVecType& kv_offset_vec) +{ + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + // if(blockIdx.x + blockIdx.y + blockIdx.z + threadIdx.y + threadIdx.z == 0 && threadIdx.x == 0) + // if(blockIdx.x + blockIdx.y + threadIdx.y + threadIdx.z == 0 && blockIdx.z == 3 && + // threadIdx.x == 102) + // { + // printf("kIsSglangLayout=%d\n", kIsSglangLayout); + // if constexpr(kIsKcache) + // { + // printf("k_id: blkz=%d, thr_idx=%d, thr_coord_st=%d\n", + // blockIdx.z, + // threadIdx.x, + // thread_coord_start); + // } + // else + // { + // printf("v_id: blkz=%d, thr_idx=%d, thr_coord_st=%d\n", + // blockIdx.z, + // threadIdx.x, + // thread_coord_start); + // } + // } + + if constexpr(kIsSglangLayout) + { + static_for<0, kLoopCount, 1>{}([&](auto k0) { + kv_offset_vec[k0] = + page_vec[thread_coord_start + kLoopStart + kLoopStride * k0.value] * stride_kv; + }); + } + else + { +#if USED_VLLM_PAGE_TABLE_VERSION == 3 + constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + if constexpr(kIsKcache) + { + // for k_offset_vec + constexpr index_t kPageStride = kLoopStride >> kPageShiftSize; + // constexpr array kPageIdArray = []() { + // array arr; + // static_for<0, kLoopCount, 1>{}([&](auto k0) { + // constexpr index_t kPageId = kPageStride * k0.value; + // arr[k0] = kPageId; + // }); + // return arr; + // }(); + static_for<0, kLoopCount, 1>{}([&](auto k0) { + // constexpr index_t kPageId = kPageIdArray[k0]; + constexpr index_t kPageId = kPageStride * k0.value; + const index_t page_offset = + (thread_coord_start + kLoopStride * k0.value) & kPageMask; + kv_offset_vec[k0] = + ((page_vec[kPageId] << kPageShiftSize) + page_offset) * stride_kv; + }); + } + else + { + // for v_offset_vec + const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); + const index_t lane0_page_id = (lane0_start + kLoopStart) >> kPageShiftSize; + const index_t page_loc = page_vec[lane0_page_id] << kPageShiftSize; + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t page_offset = + (thread_coord_start + kLoopStart + k0.value) & kPageMask; + kv_offset_vec[k0] = (page_loc + page_offset) * stride_kv; + }); + } + +#else + + index_t i_page; + index_t i_seq; + static_for<0, kLoopCount, 1>{}([&](auto k0) { +#if USED_VLLM_PAGE_TABLE_VERSION == 0 + int32_t seqlen_v_idx_per_repeat = + thread_coord_start + kLoopStart + kLoopStride * k0.value; + i_page = seqlen_v_idx_per_repeat / kPageBlockSize; + i_seq = seqlen_v_idx_per_repeat % kPageBlockSize; + kv_offset_vec[k0] = (page_vec[i_page] * kPageBlockSize + i_seq) * stride_kv; + +#elif USED_VLLM_PAGE_TABLE_VERSION == 1 + if constexpr(kIsKcache) + { + // for k_offset_vec + constexpr index_t kItemOffset = + (kLoopStart + kLoopStride * k0.value) >> kPageShiftSize; + i_page = (thread_coord_start >> kPageShiftSize) + kItemOffset; + kv_offset_vec[k0] = + ((page_vec[i_page] << kPageShiftSize) + thread_coord_start) * stride_kv; + } + else + { + // for v_offset_vec + constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + int32_t seqlen_v_idx_per_repeat = + thread_coord_start + kLoopStart + kLoopStride * k0.value; + i_page = seqlen_v_idx_per_repeat >> kPageShiftSize; + i_seq = seqlen_v_idx_per_repeat & kPageMask; + kv_offset_vec[k0] = ((page_vec[i_page] << kPageShiftSize) + i_seq) * stride_kv; + } + +#elif USED_VLLM_PAGE_TABLE_VERSION == 2 + if constexpr(kIsKcache) + { + // for k_offset_vec + // index_t i_page; + constexpr index_t kItemOffset = kLoopStride * k0.value >> kPageShiftSize; + // i_page = (thread_coord_start >> kPageShiftSize) + kItemOffset; + asm volatile("v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize], + % [thread_coord_start]\n\t " " v_add_u32_e32 % [i_page], + % [kItemOffset], + % [i_page]\n\t " : [i_page] " + v "(i_page) : [thread_coord_start] + "v"(thread_coord_start), + [kItemOffset] "i"(kItemOffset), + [kPageShiftSize] "i"(kPageShiftSize)); + kv_offset_vec[k0] = + ((page_vec[i_page] << kPageShiftSize) + thread_coord_start) * stride_kv; + } + else + { + constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + constexpr index_t kItemOffset = kLoopStart + kLoopStride * k0.value; + // i_page = thread_coord_start + kItemOffset + // i_seq = i_page & (kPageBlockSize - 1) + // i_page = i_page >> log2(kPageBlockSize) + asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], + % [thread_coord_start]\n\t + " + "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t" + "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize], + % [i_page]\n\t " : [i_page] " + + v "(i_page), [i_seq] " = v "(i_seq) + : [thread_coord_start] "v"(thread_coord_start), + [kItemOffset] "i"(kItemOffset), + [kPageMask] "i"(kPageMask), + [kPageShiftSize] "i"(kPageShiftSize)); + kv_offset_vec[k0] = ((page_vec[i_page] << kPageShiftSize) + i_seq) * stride_kv; + } +#endif + }); + +#endif + } +} + // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) template @@ -41,17 +211,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr index_t kPageBlockSize = 16; + // static constexpr index_t kPageBlockSize = 1; + + static constexpr index_t kPageShiftSize = 4; + // static constexpr index_t kPageShiftSize = 0; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -65,8 +241,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; - // static constexpr bool kIsSglangLayout = Problem::kIsSglangLayout; static constexpr bool kIsSglangLayout = Problem::kIsSglangLayout; + // static constexpr bool kIsSglangLayout = true; static constexpr bool kIsChunkedPrefill = Problem::kIsChunkedPrefill; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; @@ -126,6 +302,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else + // return 1; return 2; } else if constexpr(kQKHeaddim <= 192) @@ -327,21 +504,71 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using KDstrEncode = typename decltype(k_dist)::DstrEncode; constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; - if constexpr(kIsSglangLayout) - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); - } - else - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value; - int32_t i_page = seqlen_k_idx_per_repeat / page_block_size; - int32_t i_seq = seqlen_k_idx_per_repeat % page_block_size; - k_offsets[n0] = (page_idx[i_page] * page_block_size + i_seq) * stride_k; - }); - } + + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kPageShiftSize, + 0, + NRepeat, + kN0 / NRepeat, + kIsSglangLayout, + true>(page_idx, stride_k, k_coord, k_offsets); + + // if constexpr(kIsSglangLayout) + // { + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * + // stride_k; + // }); + // } + // else + // { + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // #if USED_VLLM_PAGE_TABLE_VERSION == 0 + // int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value; + // int32_t i_page = seqlen_k_idx_per_repeat / + // kPageBlockSize; int32_t i_seq = seqlen_k_idx_per_repeat + // % kPageBlockSize; k_offsets[n0] = (page_idx[i_page] * kPageBlockSize + + // i_seq) * stride_k; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 1 + // // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // // int32_t seqlen_k_idx_per_repeat = k_coord[0] + n0.value * kN0 / + // NRepeat; + // // int32_t i_page = seqlen_k_idx_per_repeat >> + // kPageShiftSize; + // // int32_t i_seq = seqlen_k_idx_per_repeat & kPageMask; + // // k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) * + // stride_k; + + // constexpr index_t kItemOffset = n0.value * kN0 / NRepeat >> + // kPageShiftSize; int32_t i_page = (k_coord[0] >> + // kPageShiftSize) + kItemOffset; k_offsets[n0] = ((page_idx[i_page] << + // kPageShiftSize) + k_coord[0]) * stride_k; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 2 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // constexpr index_t kItemOffset = n0.value * kN0 / NRepeat; + // index_t i_page; + // index_t i_seq; + // // i_page = k_coord[0] + kItemOffset + // // i_seq = i_page & (kPageBlockSize - 1) + // // i_page = i_page >> log2(kPageBlockSize) + // asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], %[k_coord_0]\n\t" + // "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t" + // "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize], + // % [i_page]\n\t " : [i_page] " + + // v "(i_page), [i_seq] " = v "(i_seq) + // : [k_coord_0] "v"(k_coord[0]), + // [kItemOffset] "i"(kItemOffset), + // [kPageMask] "i"(kPageMask), + // [kPageShiftSize] "i"(kPageShiftSize)); + // k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) * + // stride_k; + // #endif + // }); + // } + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), @@ -375,22 +602,62 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using VDstrEncode = typename decltype(v_dist)::DstrEncode; constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; statically_indexed_array v_offsets; - (void)stride_k; - if constexpr(kIsSglangLayout) - { - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - } - else - { - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value; - int32_t i_page = seqlen_v_idx_per_repeat / page_block_size; - int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size; - v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v; - }); - } + + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kPageShiftSize, + 0, + V_KRepeat, + 1, + kIsSglangLayout, + false>(page_idx, stride_v, v_coord, v_offsets); + + // // (void)stride_k; + // if constexpr(kIsSglangLayout) + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; + // }); + // } + // else + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // #if USED_VLLM_PAGE_TABLE_VERSION == 0 + // int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value; + // int32_t i_page = seqlen_v_idx_per_repeat / + // kPageBlockSize; int32_t i_seq = seqlen_v_idx_per_repeat + // % kPageBlockSize; v_offsets[k0] = (page_idx[i_page] * kPageBlockSize + + // i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 1 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + k0.value; + // int32_t i_page = seqlen_v_idx_per_repeat >> + // kPageShiftSize; int32_t i_seq = seqlen_v_idx_per_repeat + // & kPageMask; v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + + // i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 2 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // constexpr index_t kItemOffset = k0.value; + // index_t i_page; + // index_t i_seq; + // // i_page = v_coord_i + kItemOffset + // // i_seq = i_page & (kPageBlockSize - 1) + // // i_page = i_page >> log2(kPageBlockSize) + // asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], %[v_coord_i]\n\t" + // "v_and_b32_e32 %[i_seq], %[kPageMask], %[i_page]\n\t" + // "v_lshrrev_b32_e32 %[i_page], %[kPageShiftSize], + // %[i_page]\n\t" : [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) + // : [v_coord_i] "v"(v_coord[VPageIndexDim]), + // [kItemOffset] "i"(kItemOffset), + // [kPageMask] "i"(kPageMask), + // [kPageShiftSize] "i"(kPageShiftSize)); + // v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) * + // stride_v; + // #endif + // }); + // } auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -454,23 +721,66 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto bias_tile = load_tile(bias_dram_window); // load bias tile auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); - if constexpr(kIsSglangLayout) - { - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - } - else - { - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - int32_t seqlen_v_idx_per_repeat = kK1 + v_coord[VPageIndexDim] + k0.value; - int32_t i_page = seqlen_v_idx_per_repeat / page_block_size; - int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size; - v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v; - }); - } - v_dram_window.update_page_idx(v_offsets); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kPageShiftSize, + kK1, + V_KRepeat, + 1, + kIsSglangLayout, + false>(page_idx, stride_v, v_coord, v_offsets); + + // if constexpr(kIsSglangLayout) + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] + // * stride_v; + // }); + // } + // else + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // #if USED_VLLM_PAGE_TABLE_VERSION == 0 + // int32_t seqlen_v_idx_per_repeat = kK1 + v_coord[VPageIndexDim] + + // k0.value; int32_t i_page = + // seqlen_v_idx_per_repeat / kPageBlockSize; int32_t i_seq = + // seqlen_v_idx_per_repeat % kPageBlockSize; v_offsets[k0] = + // (page_idx[i_page] * kPageBlockSize + i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 1 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // int32_t seqlen_v_idx_per_repeat = v_coord[VPageIndexDim] + kK1 + + // k0.value; int32_t i_page = + // seqlen_v_idx_per_repeat >> kPageShiftSize; int32_t i_seq = + // seqlen_v_idx_per_repeat & kPageMask; v_offsets[k0] = + // ((page_idx[i_page] << kPageShiftSize) + i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 2 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // constexpr index_t kItemOffset = kK1 + k0.value; + // index_t i_page; + // index_t i_seq; + // // i_page = v_coord_i + kItemOffset + // // i_seq = i_page & (kPageBlockSize - 1) + // // i_page = i_page >> log2(kPageBlockSize) + // asm volatile("v_add_u32_e32 %[i_page], %[kItemOffset], + // %[v_coord_i]\n\t" + // "v_and_b32_e32 %[i_seq], %[kPageMask], + // %[i_page]\n\t" "v_lshrrev_b32_e32 %[i_page], + // %[kPageShiftSize], %[i_page]\n\t" : [i_page] + // "+v"(i_page), [i_seq] "=v"(i_seq) : [v_coord_i] + // "v"(v_coord[VPageIndexDim]), + // [kItemOffset] "i"(kItemOffset), + // [kPageMask] "i"(kPageMask), + // [kPageShiftSize] "i"(kPageShiftSize)); + // v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + i_seq) * + // stride_v; + // #endif + // }); + // } + + v_dram_window.update_page_idx(v_offsets); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -625,10 +935,68 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = - page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); + + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kPageShiftSize, + 2 * kK1, + V_KRepeat, + 1, + kIsSglangLayout, + false>(page_idx, stride_v, v_coord, v_offsets); + + // if constexpr(kIsSglangLayout) + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // v_offsets[k0] = + // page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] + // * stride_v; + // }); + // } + // else + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // #if USED_VLLM_PAGE_TABLE_VERSION == 0 + // int32_t seqlen_v_idx_per_repeat = + // kK1 * 2 + v_coord[VPageIndexDim] + k0.value; + // int32_t i_page = seqlen_v_idx_per_repeat / + // kPageBlockSize; int32_t i_seq = seqlen_v_idx_per_repeat + // % kPageBlockSize; v_offsets[k0] = (page_idx[i_page] * + // kPageBlockSize + i_seq) * stride_k; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 1 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - 1; + // int32_t seqlen_v_idx_per_repeat = + // v_coord[VPageIndexDim] + 2 * kK1 + k0.value; + // int32_t i_page = seqlen_v_idx_per_repeat >> + // kPageShiftSize; int32_t i_seq = seqlen_v_idx_per_repeat + // & kPageMask; v_offsets[k0] = ((page_idx[i_page] << + // kPageShiftSize) + i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 2 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - + // 1; constexpr index_t kItemOffset = 2 * kK1 + k0.value; + // index_t i_page; + // index_t i_seq; + // // i_page = v_coord_i + kItemOffset + // // i_seq = i_page & (kPageBlockSize - 1) + // // i_page = i_page >> log2(kPageBlockSize) + // asm volatile( + // "v_add_u32_e32 %[i_page], %[kItemOffset], + // %[v_coord_i]\n\t" "v_and_b32_e32 %[i_seq], + // %[kPageMask], %[i_page]\n\t" "v_lshrrev_b32_e32 + // %[i_page], %[kPageShiftSize], %[i_page]\n\t" : + // [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) : + // [v_coord_i] "v"(v_coord[VPageIndexDim]), + // [kItemOffset] "i"(kItemOffset), + // [kPageMask] "i"(kPageMask), + // [kPageShiftSize] "i"(kPageShiftSize)); + // v_offsets[k0] = ((page_idx[i_page] << kPageShiftSize) + + // i_seq) * stride_v; + // #endif + // }); + // } + v_dram_window.update_page_idx(v_offsets); } __builtin_amdgcn_sched_barrier(0); @@ -749,24 +1117,85 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - if constexpr(kIsSglangLayout) - { - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + - v_coord[VPageIndexDim] + k0.value] * - stride_v; - }); - } - else - { - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - int32_t seqlen_v_idx_per_repeat = kK1 * 2 + i_k1.value * kK1 + - v_coord[VPageIndexDim] + k0.value; - int32_t i_page = seqlen_v_idx_per_repeat / page_block_size; - int32_t i_seq = seqlen_v_idx_per_repeat % page_block_size; - v_offsets[k0] = (page_idx[i_page] * page_block_size + i_seq) * stride_v; - }); - } + + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kPageShiftSize, + (2 + i_k1.value) * kK1, + V_KRepeat, + 1, + kIsSglangLayout, + false>(page_idx, stride_v, v_coord, v_offsets); + + // if constexpr(kIsSglangLayout) + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // v_offsets[k0] = page_idx[kK1 * 2 + + // i_k1.value * kK1 + + // v_coord[VPageIndexDim] + // + k0.value] * + // stride_v; + // }); + // } + // else + // { + // static_for<0, V_KRepeat, 1>{}([&](auto k0) { + // #if USED_VLLM_PAGE_TABLE_VERSION == 0 + // int32_t seqlen_v_idx_per_repeat = + // kK1 * 2 + i_k1.value * kK1 + + // v_coord[VPageIndexDim] + k0.value; + // int32_t i_page = seqlen_v_idx_per_repeat + // / kPageBlockSize; int32_t i_seq = + // seqlen_v_idx_per_repeat % kPageBlockSize; + // v_offsets[k0] = + // (page_idx[i_page] * kPageBlockSize + + // i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 1 + // constexpr index_t kPageMask = (1 << + // kPageShiftSize) - 1; int32_t + // seqlen_v_idx_per_repeat = + // v_coord[VPageIndexDim] + 2 * kK1 + + // i_k1.value * kK1 + k0.value; + // int32_t i_page = seqlen_v_idx_per_repeat + // >> kPageShiftSize; int32_t i_seq = + // seqlen_v_idx_per_repeat & kPageMask; + // v_offsets[k0] = + // ((page_idx[i_page] << kPageShiftSize) + // + i_seq) * stride_v; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 2 + // constexpr index_t kPageMask = (1 << + // kPageShiftSize) - 1; constexpr index_t + // kItemOffset = + // 2 * kK1 + i_k1.value * kK1 + + // k0.value; + // index_t i_page; + // index_t i_seq; + // // i_page = v_coord_i + kItemOffset + // // i_seq = i_page & (kPageBlockSize - 1) + // // i_page = i_page >> + // log2(kPageBlockSize) asm volatile( + // "v_add_u32_e32 %[i_page], + // %[kItemOffset], %[v_coord_i]\n\t" + // "v_and_b32_e32 %[i_seq], + // %[kPageMask], %[i_page]\n\t" + // "v_lshrrev_b32_e32 %[i_page], + // %[kPageShiftSize], %[i_page]\n\t" : + // [i_page] "+v"(i_page), [i_seq] + // "=v"(i_seq) : [v_coord_i] + // "v"(v_coord[VPageIndexDim]), + // [kItemOffset] "i"(kItemOffset), + // [kPageMask] "i"(kPageMask), + // [kPageShiftSize] + // "i"(kPageShiftSize)); + // v_offsets[k0] = + // ((page_idx[i_page] << kPageShiftSize) + // + i_seq) * stride_v; + // #endif + // }); + // } + v_dram_window.update_page_idx(v_offsets); } block_sync_lds(); @@ -807,26 +1236,86 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - page_idx += kN0; + if constexpr(kIsSglangLayout) + { + page_idx += kN0; + } + else + { + page_idx += kN0 / kPageBlockSize; + } // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - if constexpr(kIsSglangLayout) - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); - } - else - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / NRepeat * n0.value; - int32_t i_page = seqlen_k_idx_per_repeat / page_block_size; - int32_t i_seq = seqlen_k_idx_per_repeat % page_block_size; - k_offsets[n0] = (page_idx[i_page] * page_block_size + i_seq) * stride_k; - }); - } + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kPageShiftSize, + 0, + NRepeat, + kN0 / NRepeat, + kIsSglangLayout, + true>(page_idx, stride_k, k_coord, k_offsets); + + // if constexpr(kIsSglangLayout) + // { + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * + // n0.value] * stride_k; + // }); + // } + // else + // { + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // #if USED_VLLM_PAGE_TABLE_VERSION == 0 + // int32_t seqlen_k_idx_per_repeat = k_coord[0] + kN0 / + // NRepeat * n0.value; int32_t i_page = + // seqlen_k_idx_per_repeat / kPageBlockSize; int32_t i_seq + // = seqlen_k_idx_per_repeat % kPageBlockSize; k_offsets[n0] + // = (page_idx[i_page] * kPageBlockSize + i_seq) * stride_k; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 1 + // // constexpr index_t kPageMask = (1 << + // kPageShiftSize) - 1; + // // int32_t seqlen_k_idx_per_repeat = k_coord[0] + + // n0.value * kN0 / NRepeat; + // // int32_t i_page = + // seqlen_k_idx_per_repeat >> + // // kPageShiftSize; int32_t i_seq = + // seqlen_k_idx_per_repeat + // // & kPageMask; k_offsets[n0] = ((page_idx[i_page] << + // kPageShiftSize) + + // // i_seq) * stride_k; + + // constexpr index_t kItemOffset = n0.value * kN0 / NRepeat + // >> kPageShiftSize; int32_t i_page = (k_coord[0] >> + // kPageShiftSize) + kItemOffset; k_offsets[n0] = + // ((page_idx[i_page] << kPageShiftSize) + k_coord[0]) * + // stride_k; + // #elif USED_VLLM_PAGE_TABLE_VERSION == 2 + // constexpr index_t kPageMask = (1 << kPageShiftSize) - + // 1; constexpr index_t kItemOffset = n0.value * kN0 / + // NRepeat; index_t i_page; index_t i_seq; + // // i_page = k_coord[0] + kItemOffset + // // i_seq = i_page & (kPageBlockSize - 1) + // // i_page = i_page >> log2(kPageBlockSize) + // asm volatile( + // "v_add_u32_e32 %[i_page], %[kItemOffset], + // %[k_coord_0]\n\t" "v_and_b32_e32 %[i_seq], + // %[kPageMask], %[i_page]\n\t" "v_lshrrev_b32_e32 + // %[i_page], %[kPageShiftSize], %[i_page]\n\t" : + // [i_page] "+v"(i_page), [i_seq] "=v"(i_seq) : + // [k_coord_0] "v"(k_coord[0]), + // [kItemOffset] "i"(kItemOffset), + // [kPageMask] "i"(kPageMask), + // [kPageShiftSize] "i"(kPageShiftSize)); + // k_offsets[n0] = ((page_idx[i_page] << kPageShiftSize) + + // i_seq) * stride_k; + // #endif + // }); + // } + k_dram_window.update_page_idx(k_offsets); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{}))