This commit is contained in:
coderfeli
2025-06-06 01:55:23 +00:00
parent da51fd4959
commit f69215b3ec
2 changed files with 33 additions and 9 deletions

View File

@@ -715,18 +715,19 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
// sequence<kPadSeqLenK_, kPadHeadDimQ>{});
// }();
const auto k_dram = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages, kargs.hdim_q / 8, 16, 8),
make_tuple(kargs.num_total_pages / 16, kargs.hdim_q / 8, 16, 8),
make_tuple(kargs.hdim_q * 16, 256, 8, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
// constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
// return pad_tensor_view(
// k_dram_naive,
// make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
// sequence<kPadSeqLenK_, kPadHeadDimQ>{});
return transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages / 16, 16)),
make_merge_transform(make_tuple(kargs.hdim_q / 8, 8))),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)

View File

@@ -7,7 +7,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
// #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
namespace ck_tile {
@@ -246,10 +246,33 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemmPreshuffled()
{
using GemmProblem =
BlockGemmProblem<typename Problem::QDataType,