diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 91364fb1d6..5e8a25598a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -448,14 +448,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 56a77ac7dc..47b994e110 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -401,7 +401,7 @@ bool run(const ck_tile::ArgParser& arg_parser) cache_seqlen_ks.begin(), [&](auto seqlen_k) { return seqlen_k - seqlen_knew; }); -#if 0 +#if 1 // clang-format off std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; @@ -481,12 +481,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); // std::vector page_idx_host(seqstart_k_host.back(), 0); - ck_tile::HostTensor page_idx_host({seqstart_k_host.back()}); - // std::iota(page_idx_host.begin(), page_idx_host.end(), 0); - iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0); - // for (int i = 0; i < page_idx_host.get_element_space_size(); i++) { - // page_idx_host(i) = (i + 19) % page_idx_host.size(); - // } + ck_tile::HostTensor page_idx_host({seqstart_k_host.back() + 65536}); + std::iota(page_idx_host.begin(), page_idx_host.end(), 0); + iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end() - 65536, 0); page_idx_host.savetxt("page_idx_host.txt", "int"); using TypeConfig = FmhaFwdTypeConfig; @@ -1023,6 +1020,8 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); args.page_idx_ptr = (mode == mode_enum::group ? page_idx.GetDeviceBuffer() : nullptr); + args.page_num = + (mode == mode_enum::group ? shape_seqlen_k : 0); args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 1f02e9f729..b3897a9c90 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -131,6 +131,7 @@ struct fmha_fwd_args seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr const void* page_idx_ptr; + ck_tile::index_t page_num; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -328,6 +329,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_k_ptr, args.page_idx_ptr, + args.page_num, args.hdim_q, args.hdim_v, args.nhead_q, diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 8c74d9c1eb..280ffd8c5f 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -338,7 +338,6 @@ struct tile_scatter_gather constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_gather = idx_ys_start[number{}]; const auto page_offset = page_idx_[idx_gather]; - // read from bottom tensor const vector_t vec_value = get_bottom_tensor_view().template get_vectorized_elements( diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 46a013f5db..00366c50c0 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -266,6 +266,7 @@ struct FmhaFwdKernel const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; const int32_t* page_idx; + int32_t page_num; }; using Kargs = std::conditional_t; @@ -598,6 +599,7 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_k_ptr, const void* page_idx_ptr, + ck_tile::index_t page_num, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -657,7 +659,8 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr), - reinterpret_cast(page_idx_ptr)}; + reinterpret_cast(page_idx_ptr), + page_num}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -724,6 +727,7 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_k_ptr, const void* page_idx_ptr, + ck_tile::index_t page_num, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -763,6 +767,7 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_k_ptr, page_idx_ptr, + page_num, hdim_q, hdim_v, num_head_q, @@ -805,6 +810,7 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_k_ptr, const void* page_idx_ptr, + ck_tile::index_t page_num, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -844,6 +850,7 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_k_ptr, page_idx_ptr, + page_num, hdim_q, hdim_v, num_head_q, @@ -966,7 +973,7 @@ struct FmhaFwdKernel long_index_t batch_offset_q = 0; // long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; + // long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; long_index_t batch_offset_randval = 0; long_index_t batch_offset_lse = 0; @@ -980,16 +987,19 @@ struct FmhaFwdKernel batch_offset_q = query_start * kargs.stride_q; // batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } + // if constexpr(std::is_same_v) + // { + // batch_offset_v = key_start * kargs.stride_v; + // } + // else + // { + // batch_offset_v = key_start; + // } kargs.page_idx += key_start; + // if(threadIdx.x==0){ + // printf("\nbid %d %d page id %d pagev %d\n", blockIdx.z, i_batch, key_start, kargs.page_idx[0]); + // } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias + key_start; @@ -1050,17 +1060,17 @@ struct FmhaFwdKernel const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q; - // const KDataType* k_ptr = - // reinterpret_cast(kargs.k_ptr) + - // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - // batch_offset_k; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k; + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;// + + // batch_offset_k; + // const KDataType* k_ptr = + // reinterpret_cast(kargs.k_ptr) + + // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;// + + // batch_offset_v; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; @@ -1091,7 +1101,7 @@ struct FmhaFwdKernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.page_num, kargs.hdim_q), make_tuple(kargs.stride_k, 1), number{}, number<1>{}); @@ -1107,7 +1117,7 @@ struct FmhaFwdKernel { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.page_num, kargs.hdim_v), make_tuple(kargs.stride_v, 1), number{}, number<1>{}); @@ -1115,7 +1125,7 @@ struct FmhaFwdKernel const auto v_dram_transposed = transform_tensor_view(v_dram_naive, make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), + make_pass_through_transform(kargs.page_num)), make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 75364f4138..4b3190838f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -10,7 +10,6 @@ #include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_scatter_gather.hpp" -// #include "ck_tile/core/tensor/tile_scatter_gather_debug.hpp" namespace ck_tile { @@ -315,6 +314,11 @@ struct BlockFmhaPipelineQRKSVS k_dram_block_window.get_window_origin(), k_dist, k_offsets); // K DRAM tile window for + // auto k_dram_window = make_tile_window_debug( + // k_dram_block_window.get_bottom_tensor_view(), + // k_dram_block_window.get_window_lengths(), + // k_dram_block_window.get_window_origin(), + // k_dist); // K DRAM tile window for auto k_block_tile = load_tile(k_dram_window); { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 97dd2ec710..cb7abb8b89 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -67,7 +67,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ();