diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 00e0a16536..1f0d73d950 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1046,6 +1046,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); args.batch_stride_block_table = batch_stride_block_table; args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration args.cache_batch_idx = (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 704453baa4..8a821b9177 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -165,6 +165,8 @@ struct fmha_fwd_splitkv_args void* block_table_ptr; ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. const void* cache_batch_idx; @@ -173,12 +175,21 @@ struct fmha_fwd_splitkv_args // seqlen_k = kargs.seqlen_k // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // // batch mode (kvcache): // seqlen_q = kargs.seqlen_q // seqlen_k = kargs.seqlen_k_ptr[b] // group mode (kvcache): // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; @@ -395,6 +406,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.block_table_ptr, args.batch_stride_block_table, args.page_block_size, + args.is_gappy, args.scale_s, args.scale_p, args.stride_q, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 3c4e02d08b..dcb671d81e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -172,13 +172,18 @@ struct FmhaFwdSplitKVKernel float scale_p; }; - struct PageBlockTableKargs + struct CommonPageBlockTableKargs { const int32_t* block_table_ptr; ck_tile::index_t batch_stride_block_table; ck_tile::index_t page_block_size; }; + struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs + { + bool is_gappy = false; + }; + struct CacheBatchIdxKargs { const int32_t* cache_batch_idx; @@ -193,7 +198,7 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t + std::conditional_t { const int32_t* seqlen_k_ptr; @@ -215,7 +220,7 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -375,6 +380,7 @@ struct FmhaFwdSplitKVKernel const void* block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, + bool is_gappy, float scale_s, float scale_p, ck_tile::index_t stride_q, @@ -461,6 +467,7 @@ struct FmhaFwdSplitKVKernel kargs.block_table_ptr = reinterpret_cast(block_table_ptr); kargs.batch_stride_block_table = batch_stride_block_table; kargs.page_block_size = page_block_size; + kargs.is_gappy = is_gappy; } return kargs; @@ -495,11 +502,13 @@ struct FmhaFwdSplitKVKernel const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; + long_index_t batch_offset_k = 0; // unused for paged-kvcache + long_index_t batch_offset_v = 0; // unused for paged-kvcache long_index_t batch_offset_bias = 0; long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_o_acc = 0; + index_t kv_l2p_offset = + 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache if constexpr(kIsGroupMode) { @@ -508,22 +517,14 @@ struct FmhaFwdSplitKVKernel const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; batch_offset_q = query_start * kargs.stride_q; - if constexpr(kIsPagedKV) + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) { - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -551,6 +552,15 @@ struct FmhaFwdSplitKVKernel { kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } + + if constexpr(kIsPagedKV) + { + if(kargs.is_gappy) + { + // seqstart_k_ptr has different meaning in this case + kv_l2p_offset = kargs.seqstart_k_ptr[i_batch]; + } + } } else { @@ -703,7 +713,7 @@ struct FmhaFwdSplitKVKernel reinterpret_cast(kargs.block_table_ptr) + i_batch_ * kargs.batch_stride_block_table; const index_t num_blocks = - integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); const long_index_t fixed_offset = static_cast(i_nhead_ / kargs.nhead_ratio_qk) * @@ -718,7 +728,8 @@ struct FmhaFwdSplitKVKernel kargs.page_block_size, k_dram, make_k_dram(nullptr, - kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size)); + (kv_l2p_offset + kargs.seqlen_k) - + (num_blocks - 1) * kargs.page_block_size)); } else { @@ -733,7 +744,7 @@ struct FmhaFwdSplitKVKernel reinterpret_cast(kargs.block_table_ptr) + i_batch_ * kargs.batch_stride_block_table; const index_t num_blocks = - integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); const long_index_t fixed_offset = static_cast(i_nhead_ / kargs.nhead_ratio_qk) * @@ -748,7 +759,8 @@ struct FmhaFwdSplitKVKernel kargs.page_block_size, v_dram, make_v_dram(nullptr, - kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size)); + (kv_l2p_offset + kargs.seqlen_k) - + (num_blocks - 1) * kargs.page_block_size)); } else { @@ -896,6 +908,7 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + kv_l2p_offset, smem_ptr); } else @@ -912,6 +925,7 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + kv_l2p_offset, smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp index 675a31019e..5a52fa0f67 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -18,11 +18,11 @@ struct FmhaFwdSplitKVTilePartitioner static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, - ck_tile::index_t max_seqlen_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_splits) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 4e8d8694d7..04aa85644d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -143,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { static_assert( @@ -211,16 +212,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + const auto q_origin = q_dram_window.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) { - const index_t original_num_total_loop = - integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - if(original_num_total_loop <= 0) + const index_t logical_num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(logical_num_total_loop <= 0) { if constexpr(kStoreLSE) { @@ -239,33 +240,41 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } - // make sure the first tile is completely located in page-block - const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] { - if constexpr(kIsPagedKV) - { - return kN0 * integer_divide_floor(seqlen_k_start_, kN0); - } - else - { - return seqlen_k_start_; - } - }(); + const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; + const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = + [&, physical_seqlen_k_start_ = physical_seqlen_k_start] { + if constexpr(kIsPagedKV) + { + return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0); + } + else + { + return physical_seqlen_k_start_; + } + }(); const index_t num_total_loop = - integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0}); + k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), + logical_seqlen_k_start - (physical_seqlen_k_start - + aligned_physical_seqlen_k_start)}, // M/N Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, adjusted_seqlen_k_start}, // TODO: hdim split? + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -379,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS constexpr auto i_j_idx = make_tuple(idx0, idx1); s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); + // position_encoding accept only logical coordinates, do conversion here + position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset); }); }); } @@ -397,29 +407,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { const auto k_origin = k_page_block_navigator.to_global_window_origin( i_page_block_k, k_dram_block_window.get_window_origin()); - set_tile_if(s_acc, - -numeric::infinity(), - [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end]( - auto tile_idx) { - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - if constexpr(kIsPagedKV) - { - return col < seqlen_k_start_ || seqlen_k_end_ <= col; - } - else - { - return seqlen_k_end_ <= col; - } - }); + set_tile_if( + s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + if constexpr(kIsPagedKV) + { + return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; + } + else + { + return physical_seqlen_k_end_ <= col; + } + }); } if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { const auto k_origin = k_page_block_navigator.to_global_window_origin( i_page_block_k, k_dram_block_window.get_window_origin()); + // mask accept only logical coordinates, do conversion here bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), + k_origin.at(number<0>{}) - kv_l2p_offset, number{}, number{}); if(need_perpixel_check) @@ -428,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return mask.IsOutOfBound(row, col - kv_l2p_offset); }); } } @@ -659,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { return operator()(q_dram_block_window_tmp, @@ -681,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS mask, position_encoding, scale_s, + kv_l2p_offset, smem_ptr); } };