diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b28d255697..f6949ee9af 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -308,6 +308,14 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); +#if !CK_TILE_FMHA_FWD_SPLITKV_API + if(use_cache_batch_idx) + { + std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } +#endif if(0 < page_block_size && use_cache_batch_idx) { std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " @@ -317,7 +325,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } auto mode = static_cast(arg_parser.get_uint32("mode")); - if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) + if((0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size) && mode != mode_enum::batch) { std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; mode = mode_enum::batch; @@ -531,7 +539,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } #if CK_TILE_FMHA_FWD_SPLITKV_API - if(0 < p_drop && (1 < num_splits || 0 < page_block_size)) + if(0 < p_drop && (1 < num_splits || use_cache_batch_idx || 0 < page_block_size)) { std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" << std::endl; @@ -780,6 +788,10 @@ bool run(const ck_tile::ArgParser& arg_parser) { std::cout << ", page_block_size:" << page_block_size; } + if(use_cache_batch_idx) + { + std::cout << ", cache_batch_idx:" << use_cache_batch_idx; + } #endif std::cout << std::flush; @@ -1097,7 +1109,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; // adjust matrix index according to the mode - const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = + (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t key_offset = (mode == mode_enum::batch @@ -1149,8 +1163,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } else #endif { - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); } #if CK_TILE_FMHA_FWD_APPENDKV_API @@ -1220,14 +1234,14 @@ bool run(const ck_tile::ArgParser& arg_parser) { if(is_v_rowmajor) { // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); } else { - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[1], i[2] + key_offset); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); } } diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 3d7ed70998..52c94489b3 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -32,10 +32,12 @@ done run_fp16_bf16_tests() { local NUM_SPLITS=(1) local PAGE_BLOCK_SIZE=(0) + local CACHE_BATCH_IDX=(0) if [ $TEST_SPLITKV -eq 1 ] ; then NUM_SPLITS+=(2 3) PAGE_BLOCK_SIZE+=(128) + CACHE_BATCH_IDX+=(1) fi for prec in "fp16" "bf16" ; do @@ -48,20 +50,22 @@ run_fp16_bf16_tests() { for p_drop in 0.0 0.2 ; do for num_splits in "${NUM_SPLITS[@]}" ; do for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do + for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done + done ; } run_fp8_tests() { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index cad662e7a7..7cfb27c926 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -249,14 +249,26 @@ struct FmhaFwdAppendKVKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); + const index_t i_cache_batch = [&, i_batch_ = i_batch] { + if constexpr(kIsPagedKV) + { + return i_batch_; + } + else + { + return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_] + : i_batch_); + } + }(); + const long_index_t batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; const long_index_t batch_offset_k = - static_cast(i_batch) * kargs.batch_stride_k; + static_cast(i_cache_batch) * kargs.batch_stride_k; const long_index_t batch_offset_knew = static_cast(i_batch) * kargs.batch_stride_knew; const long_index_t batch_offset_v = - static_cast(i_batch) * kargs.batch_stride_v; + static_cast(i_cache_batch) * kargs.batch_stride_v; const long_index_t batch_offset_vnew = static_cast(i_batch) * kargs.batch_stride_vnew; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 540c1e55ef..71ddf7b1f5 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -529,9 +529,21 @@ struct FmhaFwdSplitKVKernel } else { + const index_t i_cache_batch = [&, i_batch_ = i_batch] { + if constexpr(kIsPagedKV) + { + return i_batch_; + } + else + { + return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_] + : i_batch_); + } + }(); + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) {