diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp index e0491fd303..f9bb6519ce 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp @@ -715,18 +715,19 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel // sequence{}); // }(); const auto k_dram = [&]() { - return make_naive_tensor_view( + const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.num_total_pages, kargs.hdim_q / 8, 16, 8), + make_tuple(kargs.num_total_pages / 16, kargs.hdim_q / 8, 16, 8), make_tuple(kargs.hdim_q * 16, 256, 8, 1), number{}, number<1>{}); - // constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; - // return pad_tensor_view( - // k_dram_naive, - // make_tuple(number{}, number{}), - // sequence{}); + return transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages / 16, 16)), + make_merge_transform(make_tuple(kargs.hdim_q / 8, 8))), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); }(); const auto v_dram = [&]() { if constexpr(std::is_same_v) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp index 8b74930b90..3ea24a83f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +// #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" namespace ck_tile { @@ -246,10 +246,33 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy tuple, sequence<1, 0>>, sequence<1, 2>, sequence<0, 1>>{}); + + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); } template - CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemmPreshuffled() { using GemmProblem = BlockGemmProblem