From 94b6430489a7be3611234322a9e1b88ebcf0564f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 17 Jul 2025 10:06:09 +0000 Subject: [PATCH] temp save --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 3 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 7 +- include/ck_tile/core/numeric/bfloat16.hpp | 2 +- include/ck_tile/core/numeric/pk_fp4.hpp | 2 +- .../fmha/kernel/fmha_fwd_decode_kernel.hpp | 31 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 10 +- ...lock_fmha_fwd_decode_pipeline_qr_ks_vs.hpp | 424 ++++++------------ ...ha_fwd_decode_pipeline_qr_ks_vs_policy.hpp | 136 ++++++ include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 2 +- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 + 11 files changed, 298 insertions(+), 325 deletions(-) mode change 100755 => 100644 example/ck_tile/01_fmha/fmha_fwd.cpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 36807ec9a0..30e9163812 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -30,7 +30,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api ${FMHA_FWD_APIS} - # --filter fmha_fwd... + --filter fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_pddv_nlogits_nbias_nmask_nlse_nsquant_npagedkv@fmha_fwd_decode_d64_bf16_batch_b16x32x64x64x32x64_r1x1x1_r1x1x1_w16x16x32_w16x16x32_decode_qr_vr_pddv_nlogits_nbias_nmask_nlse_nsquant_npagedkv ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp old mode 100755 new mode 100644 index c66f12fafb..814c6fd055 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1042,7 +1042,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } } - else if constexpr(std::is_same_v> || std::is_same_v>) + else if constexpr(std::is_same_v> || + std::is_same_v>) { args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index bd5b5a27c1..4b0df4bf0c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -995,7 +995,7 @@ auto fmha_fwd_decode_create_kargs_and_grids(fmha_fwd_decode_args args) args.v_ptr, args.bias_ptr, args.lse_acc_ptr, - // args.o_acc_ptr, + // args.o_acc_ptr, args.o_ptr, // hardcoding args.batch, args.seqlen_q, @@ -1625,10 +1625,7 @@ struct fmha_fwd_decode_traits bool do_fp8_static_quant; // TODO: padding check is inside this api }; -float fmha_fwd_decode(fmha_fwd_decode_traits, - fmha_fwd_decode_args, - const ck_tile::stream_config&); - +float fmha_fwd_decode(fmha_fwd_decode_traits, fmha_fwd_decode_args, const ck_tile::stream_config&); struct fmha_fwd_appendkv_traits { diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 79da65e006..b0b85a8c5a 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -287,7 +287,7 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) { -#if defined (__gfx950__) +#if defined(__gfx950__) return static_cast(f); #else return bit_cast(float_to_bf16_raw(f, constant{})); diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index b7dca9dd0a..fc61fd2773 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -21,7 +21,7 @@ namespace ck_tile { using fp32_t = float; using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); -using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp index 6ec5972920..9e21ed73ce 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp @@ -757,7 +757,9 @@ struct FmhaFwdDecodeKernel const auto make_k_dram = [&](const KDataType* data, index_t height) { // We don't expect K data reuse among different blocks in decode case. - const auto k_dram_naive = make_naive_tensor_view( + const auto k_dram_naive = make_naive_tensor_view( data, // will update this pointer if using paged-kvcache make_tuple(height, kargs.hdim_q), make_tuple(kargs.stride_k, 1), @@ -784,12 +786,15 @@ struct FmhaFwdDecodeKernel if constexpr(std::is_same_v) { // We don't expect V data reuse among different blocks in decode case. - const auto v_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(length, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, - number<1>{}); + const auto v_dram_naive = + make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(length, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); const auto v_dram_transposed = transform_tensor_view(v_dram_naive, @@ -1079,15 +1084,15 @@ struct FmhaFwdDecodeKernel v_page_block_navigator, // Remove it bias_dram_window, lse_acc_dram_window, - kargs.num_splits, // Remove it - i_split_, // Remove it + kargs.num_splits, // Remove it + i_split_, // Remove it mask, position_encoding, kargs.scale_s, - variant, // Remove it - variant_params, // Remove it - block_indices, // Remove it - kv_l2p_offset, // Remove it + variant, // Remove it + variant_params, // Remove it + block_indices, // Remove it + kv_l2p_offset, // Remove it smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index daa95fb640..a419ea8ed3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -1071,20 +1071,20 @@ struct FmhaFwdSplitKVKernel { return FmhaPipeline{}(q_dram_window, k_dram_window_lengths, - // k_page_block_navigator, + // k_page_block_navigator, v_dram_window_lengths, - // v_page_block_navigator, + // v_page_block_navigator, bias_dram_window, lse_acc_dram_window, - // kargs.num_splits, - // i_split_, + // kargs.num_splits, + // i_split_, mask, position_encoding, kargs.scale_s, variant, variant_params, block_indices, - // kv_l2p_offset, + // kv_l2p_offset, smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp index a95277f620..a27fa4baa6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp @@ -11,8 +11,7 @@ namespace ck_tile { // This pipeline is qkv all located in LDS -template +template struct BlockFmhaFwdDecodePipelineQRKSVS { using Problem = remove_cvref_t; @@ -52,11 +51,13 @@ struct BlockFmhaFwdDecodePipelineQRKSVS // static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && // Problem::kPadHeadDimV == true); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; // support multiple of vector(like 8x) + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = + Problem::kPadHeadDimQ; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = + Problem::kPadHeadDimV; // support multiple of vector(like 8x) static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; static constexpr auto BiasEnum = Problem::BiasEnum; @@ -130,33 +131,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename QElementFunction, - typename KElementFunction, - typename VElementFunction, - typename BiasElementFunction, - typename LSEaccElementFunction, - typename SAccElementFunction, - typename PComputeElementFunction, - typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile const KPageBlockNavigator& k_page_block_navigator, - const KElementFunction& k_element_func, const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile const VPageBlockNavigator& v_page_block_navigator, - const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile - const LSEaccElementFunction& lse_acc_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile index_t num_splits, index_t i_split, FmhaMask mask, @@ -184,56 +169,11 @@ struct BlockFmhaFwdDecodePipelineQRKSVS kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - // Q tile in LDS - QDataType* q_lds_ptr = - static_cast(static_cast(static_cast(smem_ptr))); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - - // K tile in LDS - KDataType* k_lds_ptr = - static_cast(static_cast(static_cast(smem_ptr))); - auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - - // V tile in LDS - auto v_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - max(Policy::template GetSmemSizeQ(), - Policy::template GetSmemSizeK())), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_window = make_tile_window( - v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); - - // S tile in LDS - auto s_lds = make_tensor_view( - reinterpret_cast(reinterpret_cast(smem_ptr) + - max(Policy::template GetSmemSizeQ(), - Policy::template GetSmemSizeK())), - Policy::template MakeSLdsBlockDescriptor()); - auto s_write_lds_window = make_tile_window( - s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); - auto s_read_lds_window = - make_tile_window(s_lds, - Policy::template MakeSLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeSRegTileDistribution()); // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); - - // load Q here, will store Q into LDS to maximize throughput - auto origin_q = load_tile(q_dram_window); - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); auto s_acc = SaccBlockTileType{}; @@ -259,7 +199,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); @@ -279,8 +219,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS if(get_thread_local_1d_id() < kM0) { - store_tile(lse_acc_dram_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); + store_tile(lse_acc_dram_window_tmp, lse_acc); } } @@ -290,6 +229,25 @@ struct BlockFmhaFwdDecodePipelineQRKSVS } } + // Q tile in LDS + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); + + auto q_lds = make_tensor_view( + static_cast(smem_ptr), Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_store_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); + + async_load_tile(q_lds_store_window, q_dram_window); + + // K tile in LDS const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; // make sure the first tile is completely located in page-block (page-block size should be @@ -307,12 +265,59 @@ struct BlockFmhaFwdDecodePipelineQRKSVS return physical_seqlen_k_start_; } }(); - const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + auto k_dram_window = make_tile_window( + k_dram_block_window, Policy::template MakeKDramTileDistribution()); + + auto k_lds = make_tensor_view( + static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + auto k_lds_read_window = + make_tile_window(k_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKRegTileDistribution()); + + // S tile in LDS + auto s_lds = make_tensor_view( + reinterpret_cast(reinterpret_cast(smem_ptr) + + max(Policy::template GetSmemSizeQ(), + Policy::template GetSmemSizeK())), + Policy::template MakeSLdsBlockDescriptor()); + auto s_write_lds_window = make_tile_window( + s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); + auto s_read_lds_window = + make_tile_window(s_lds, + Policy::template MakeSLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeSRegTileDistribution()); + + // V tile in LDS + auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( + v_dram_block_window_lengths, + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + auto v_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + max(Policy::template GetSmemSizeQ(), + Policy::template GetSmemSizeK()) + + Policy::template GetSmemSizeS()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_write_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeVRegTileDistribution()); + + // Bias tile in LDS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -322,51 +327,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS aligned_physical_seqlen_k_start)}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( - v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? - Policy::template MakeVDramTileDistribution()); + block_sync_lds_direct_load<0>(); + auto q_tile = load_tile(q_lds_read_window); - // store Q into LDS - __builtin_amdgcn_sched_barrier(0); - auto q_lds_window_for_store = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); - store_tile(q_lds_window_for_store, origin_q); - __builtin_amdgcn_sched_barrier(0); - - // load Q from LDS - __builtin_amdgcn_sched_barrier(0); - auto q_lds_window_for_load = - make_tile_window(q_lds, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeQRegTileDistribution()); - block_sync_lds(); - auto q = load_tile(q_lds_window_for_load); - __builtin_amdgcn_sched_barrier(0); - auto q_tile = tile_elementwise_in(q_element_func, q); - - // prefetch K tile index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; - static_assert(1 <= k0_loops); - static_assert(1 <= k1_loops); + static_assert(1 == k0_loops); + static_assert(1 == k1_loops); - auto k_dram_window = make_tile_window( - k_dram_block_window, - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + async_load_tile(k_lds_write_window, k_dram_window); + // move K tile windows + i_page_block_k = + k_page_block_navigator.move_tile_window(i_page_block_k, k_dram_block_window, {kN0, 0}); - // load the first tile of the first iteration and store to LDS - auto k_block_tile = load_tile(k_dram_window); - // moving k_dram_window is an in-page-block operation, so there is - // no need to invoke k_page_block_navigator.move_tile_window() here. - move_tile_window(k_dram_window, {0, kK0}); - // ensure LDS access by Q is done before the over-writting by K - block_sync_lds(); - store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_dram_window = make_tile_window(k_dram_block_window, + Policy::template MakeKDramTileDistribution()); do { @@ -385,40 +365,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS 0); // prevent from messing up the order of global loads } - if constexpr(k0_loops > 1) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - k_block_tile = load_tile(k_dram_window); // global read i + 1 - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_window); - block_sync_lds(); - move_tile_window(k_dram_window, {0, kK0}); + async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile + // move V tile windows + i_page_block_v = + v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); - store_tile( - k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 - }); - } + block_sync_lds_direct_load(); + auto k_tile = load_tile(k_lds_read_window); - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - { // tail - block_sync_lds(); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_window); - } + gemm_0( + s_acc, + get_slice_tile( + q_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile( + k_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence{})); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -437,7 +401,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS const auto k_origin = k_page_block_navigator.to_global_window_origin( i_page_block_k, k_dram_block_window.get_window_origin()); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -455,7 +418,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -530,39 +492,30 @@ struct BlockFmhaFwdDecodePipelineQRKSVS } } - __builtin_amdgcn_sched_barrier(0); + async_load_tile(k_lds_write_window, k_dram_window); + i_page_block_k = k_page_block_navigator.move_tile_window( + i_page_block_k, k_dram_block_window, {kN0, 0}); - // load the first tile for next iteration - if(i_total_loops < num_total_loop - 1) - { - // move K tile windows - i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}); + k_dram_window = make_tile_window(k_dram_block_window, + Policy::template MakeKDramTileDistribution()); - k_dram_window = make_tile_window( - k_dram_block_window, - Policy::template MakeKDramTileDistribution()); // K DRAM tile window - - // laod the first tile of the first iteration and store to LDS - k_block_tile = load_tile(k_dram_window); - } - - __builtin_amdgcn_sched_barrier(0); - // In Nwarp=1 and NXdl=32, GEMM0 output naturally fit the input of GEMM1 - // Otherwise shuffle through LDS so that the tile layout is consistent with required by Gemm1 - auto s_new = [&](){ - if constexpr ( !((kNWarp==1) && (kNXdl == 32)) ){ + // Otherwise shuffle through LDS so that the tile layout is consistent with required by + // Gemm1 + auto s_new = [&]() { + if constexpr(!((kNWarp == 1) && (kNXdl == 32))) + { auto s = cast_tile(s_acc); // S{j} store_tile(s_write_lds_window, s); block_sync_lds(); return load_tile(s_read_lds_window); } - else{ + else + { return cast_tile(s_acc); // S{j} } - }(); + }(); auto m_local = block_tile_reduce( s_new, @@ -630,8 +583,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + const auto p = cast_tile(p_compute); // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); @@ -670,79 +622,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS }); }); - block_sync_lds(); - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_prefetch); - store_tile( - v_lds_window, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch - } - i_page_block_v = - v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); + block_sync_lds_direct_load(); + auto v_tile = load_tile_transpose(v_lds_read_window); - // STAGE 3, KV gemm - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&, - &i_page_block_v_ = i_page_block_v, - &v_dram_window_ = v_dram_window](auto i_k1) { - const auto v = load_tile(v_dram_window_); // load next v - block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile(v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_window); - block_sync_lds(); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v); - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v)); // store next v - } - i_page_block_v_ = v_page_block_navigator.move_tile_window( - i_page_block_v_, v_dram_window_, {0, kK1}); - }); - } - - // tail - { - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_window); - block_sync_lds(); - } - - __builtin_amdgcn_sched_barrier(0); - - // load the first tile for next iteration - if(i_total_loops < num_total_loop - 1) - { - // store the first tile for next iteration to LDS - // moving k_dram_window is an in-page-block operation, so there is - // no need to invoke k_page_block_navigator.move_tile_window() here. - move_tile_window(k_dram_window, {0, kK0}); - store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); - } } while(++i_total_loops < num_total_loop); if constexpr(kStoreLSE) @@ -777,8 +666,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS if(get_thread_local_1d_id() < kM0) { - store_tile(lse_acc_dram_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); + store_tile(lse_acc_dram_window_tmp, lse_acc); } } @@ -802,66 +690,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS }); }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile - const KPageBlockNavigator& k_page_block_navigator, - const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile - const VPageBlockNavigator& v_page_block_navigator, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile - index_t num_splits, - index_t i_split, - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - const AttentionVariant& variant, - const AttentionVariantParams& variant_params, - const BlockIndices& block_indices, - index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_lengths, - k_page_block_navigator, - identity{}, - v_dram_block_window_lengths, - v_page_block_navigator, - identity{}, - bias_dram_block_window_tmp, - identity{}, - lse_acc_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - num_splits, - i_split, - mask, - position_encoding, - scale_s, - variant, - variant_params, - block_indices, - kv_l2p_offset, - smem_ptr); - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp index ea499c4e9d..7cd6c1606b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp @@ -7,6 +7,11 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" namespace ck_tile { @@ -116,6 +121,137 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy return q_lds_block_desc; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto k_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); + + return k_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS() { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 185abccd3f..12ee11e4bc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -183,7 +183,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; + WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16>>; #else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +// template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -57,6 +59,8 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +// template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity