From 8fb567c286dcbec2225ff1a649e65a26f5b65253 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 26 Jun 2024 17:00:07 +0000 Subject: [PATCH] Fix vnew append errro --- example/ck_tile/01_fmha/fmha_fwd.cpp | 60 +++++++++++++------ example/ck_tile/01_fmha/fmha_fwd.hpp | 2 +- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 43 +++++++++++-- .../block_fmha_fwd_appendkv_pipeline.hpp | 31 ++++++++-- ...a_fwd_appendkv_pipeline_default_policy.hpp | 59 ++++++++++++++++++ 5 files changed, 165 insertions(+), 30 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 075b1c93a8..0eedac3648 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -518,8 +518,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else - return i_perm ? (shape_seqlen_k - seqlen_knew) - : nhead_k * (shape_seqlen_k - seqlen_knew); + return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; }(); const ck_tile::index_t stride_vnew = [&]() { if(is_v_rowmajor) @@ -528,16 +527,14 @@ bool run(const ck_tile::ArgParser& arg_parser) return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = - (i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) - return i_perm ? (shape_seqlen_k - seqlen_knew) * hdim_v : hdim_v; + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; else - return i_perm ? hdim_v * (shape_seqlen_k - seqlen_knew) - : (shape_seqlen_k - seqlen_knew); + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; }(); const ck_tile::index_t nhead_stride_vnew = [&]() { if(is_v_rowmajor) @@ -546,12 +543,10 @@ bool run(const ck_tile::ArgParser& arg_parser) return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = - (nhead_k * (shape_seqlen_k - seqlen_knew) * hdim_q); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); - const ck_tile::index_t batch_stride_v = - (nhead_k * hdim_v * (shape_seqlen_k - seqlen_knew)); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(), @@ -729,16 +724,43 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush; - k_buf.FromDevice(k_host.data()); - for(int row = 0; row < shape_seqlen_k; ++row) +#if defined(ENABLE_HOST_DEBUG_PRINT) + if(!do_validation) { - printf("[POYENC][HOST] k_host[%3d] = ", row); - for(int col = 0; col < hdim_q; ++col) +#if 0 + k_buf.FromDevice(k_host.data()); + for(int row = 0; row < shape_seqlen_k; ++row) { - printf("%11.7f", ck_tile::type_convert(k_host(0, 0, row, col))); + printf("[POYENC][HOST] k_host[%3d] = ", row); + for(int col = 0; col < hdim_q; ++col) + { + printf("%11.7f", ck_tile::type_convert(k_host(0, 0, row, col))); + } + printf("\n"); } - printf("\n"); +#endif + +#if 1 + v_buf.FromDevice(v_host.data()); + for(int row = 0; row < shape_seqlen_k; ++row) + { + printf("[POYENC][HOST] v_host[%3d] = ", row); + for(int col = 0; col < hdim_v; ++col) + { + if(vlayout == "r") + { + printf("%11.7f", ck_tile::type_convert(v_host(0, 0, row, col))); + } + else + { + printf("%11.7f", ck_tile::type_convert(v_host(0, 0, col, row))); + } + } + printf("\n"); + } +#endif } +#endif if(!do_validation) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 438a076604..856e936434 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -344,7 +344,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) } }(); - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v); printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n", static_cast(grids.x), static_cast(grids.y), diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 6ab5c8c96a..1efe6e115d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -270,6 +270,28 @@ struct FmhaFwdAppendKVKernel const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk); // const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); +#if defined(ENABLE_KERNEL_DEBUG_PRINT) +#define PRINTF(expr) printf("[POYENC][DEVICE] " #expr ": %2d\n", (expr)); + if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID) + { + PRINTF(kargs.stride_k); + PRINTF(kargs.nhead_stride_k); + PRINTF(kargs.batch_stride_k); + + PRINTF(kargs.stride_knew); + PRINTF(kargs.nhead_stride_knew); + PRINTF(kargs.batch_stride_knew); + + PRINTF(kargs.stride_v); + PRINTF(kargs.nhead_stride_v); + PRINTF(kargs.batch_stride_v); + + PRINTF(kargs.stride_vnew); + PRINTF(kargs.nhead_stride_vnew); + PRINTF(kargs.batch_stride_vnew); + } +#endif + long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; long_index_t batch_offset_knew = @@ -477,13 +499,26 @@ struct FmhaFwdAppendKVKernel auto v_dram_window = make_tile_window( v_dram, - make_tuple(number{}, number{}), - {kargs.seqlen_k - kargs.seqlen_knew, 0}); + make_tuple(number{}, number{}), + {0, kargs.seqlen_k - kargs.seqlen_knew}); auto vnew_dram_window = make_tile_window( vnew_dram, - make_tuple(number{}, number{}), - {i_sk, 0}); + make_tuple(number{}, number{}), + {0, i_sk}); + +#if defined(ENABLE_KERNEL_DEBUG_PRINT) + if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID) + { + printf("[POYENC][DEVICE] kargs.seqlen_k - kargs.seqlen_knew: %d\n", + kargs.seqlen_k - kargs.seqlen_knew); + printf("[POYENC][DEVICE] i_sk: %d\n", i_sk); + printf("[POYENC][DEVICE] v_dram.get_length(0): %d\n", + v_dram.get_tensor_descriptor().get_length(number<0>{})); + printf("[POYENC][DEVICE] v_dram.get_length(1): %d\n", + v_dram.get_tensor_descriptor().get_length(number<1>{})); + } +#endif FmhaPipeline{}(q_dram_window, k_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 82e5951f6a..5892e0abfe 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -121,26 +121,45 @@ struct BlockFmhaFwdAppendKVPipeline Policy::template MakeKnewDramTileDistribution()); auto knew_tile = load_tile(knew_dram_window); - if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0) + /// TODO: apply RoPE on knew_tile here + store_tile(k_dram_block_window_tmp, knew_tile); + + auto vnew_dram_block_window = + make_tile_window(vnew_dram_block_window_tmp.get_bottom_tensor_view(), + vnew_dram_block_window_tmp.get_window_lengths(), + {0, 0}); + + auto vnew_dram_window = + make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(), + vnew_dram_block_window.get_window_lengths(), + vnew_dram_block_window.get_window_origin(), + Policy::template MakeVnewDramTileDistribution()); + + auto vnew_tile = load_tile(vnew_dram_window); + +#if defined(ENABLE_PIPELINE_DEBUG_PRINT) + if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID) { - constexpr auto spans = decltype(knew_tile)::get_distributed_spans(); + printf("[POYENC][DEVICE] tid: %d\n", TID); + constexpr auto spans = decltype(vnew_tile)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - knew_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + vnew_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = tile_idx.at(number<0>{}); const auto col = tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - printf("[POYENC][DEVICE] knew_tile(%2d,%2d): %11.7f\n", + printf("[POYENC][DEVICE] vnew_tile(%2d,%2d): %11.7f\n", row, col, - type_convert(knew_tile(i_j_idx))); + type_convert(vnew_tile(i_j_idx))); }); }); } - store_tile(k_dram_block_window_tmp, knew_tile); +#endif + store_tile(v_dram_block_window_tmp, vnew_tile); } template , sequence<0, 1>>{}); } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVnewDramTileDistribution() + { + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + static_assert(!std::is_same_v); + if constexpr(std::is_same_v) + { + constexpr index_t kNPerBlock = Problem::kTileSizeDv; + constexpr index_t kKPerBlock = Problem::kTileSizeSk; + + constexpr index_t K1 = GetAlignmentV(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + using VDataType = remove_cvref_t; + + constexpr index_t kNPerBlock = Problem::kTileSizeDv; + constexpr index_t kKPerBlock = Problem::kTileSizeSk; + + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } }; } // namespace ck_tile