diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 38830ee6fe..6cfa862a1a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -23,6 +23,44 @@ namespace ck_tile { +namespace detail { + +// A helper struct for detecting n0loop +template +struct has_n0loop_flag : std::false_type +{ +}; + +template +struct has_n0loop_flag< + T, + std::enable_if_t && T::kUseN0Loop>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag::value; + +// A helper struct for detecting ignore_fast_exp2 flag +template +struct has_ignore_fast_exp2_flag : std::false_type +{ +}; + +template +struct has_ignore_fast_exp2_flag< + T, + std::enable_if_t && + T::kIgnoreFastExp2>> : std::true_type +{ +}; + +template +static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag::value; + +}; // namespace detail + template struct FmhaFwdKernel { @@ -402,7 +440,9 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), + detail::ignore_fast_exp2_v + ? scale_s + : static_cast(scale_s * ck_tile::log2e_v<>), #else scale_s, #endif @@ -741,7 +781,9 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), + detail::ignore_fast_exp2_v + ? scale_s + : static_cast(scale_s * ck_tile::log2e_v<>), #else scale_s, #endif @@ -1303,10 +1345,21 @@ struct FmhaFwdKernel number<1>{}); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + + if constexpr(detail::is_n0loop_pipeline_v) + { + return pad_tensor_view(k_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } }(); const auto v_dram = [&]() { if constexpr(std::is_same_v) @@ -1359,10 +1412,22 @@ struct FmhaFwdKernel }(), {i_m0, 0}); - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); + auto k_dram_window = [&]() { + if constexpr(detail::is_n0loop_pipeline_v) + { + return make_tile_window(k_dram, + make_tuple(number{}, + number{}), + {0, 0}); + } + else + { + return make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + }(); auto v_dram_window = make_tile_window( v_dram, @@ -1508,7 +1573,10 @@ struct FmhaFwdKernel *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); #if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; + if constexpr(!detail::ignore_fast_exp2_v) + { + slope *= ck_tile::log2e_v<>; + } #endif if constexpr(kHasMask) { @@ -2247,7 +2315,10 @@ struct FmhaFwdKernel *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); #if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; + if constexpr(!detail::ignore_fast_exp2_v) + { + slope *= ck_tile::log2e_v<>; + } #endif if constexpr(kHasMask) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 8114bb96c4..78a7f18342 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -20,7 +20,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using KDataType = remove_cvref_t; using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; using RandValOutputDataType = remove_cvref_t; using LSEDataType = remove_cvref_t; @@ -35,11 +35,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static constexpr bool kQLoadOnce = true; static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kUseN0Loop = true; + static constexpr bool kIgnoreFastExp2 = true; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; @@ -63,19 +65,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); + static constexpr index_t kAlignmentV = + Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - static constexpr index_t kAlignmentRandVal = - kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) @@ -161,6 +157,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch ignore = q_element_func; ignore = k_element_func; + // xformers path does not require the pipeline to output random values for host + // verification, since a separate kernel is used to generate random values + ignore = randval_dram_block_window_tmp; + static_assert( std::is_same_v> && std::is_same_v> && @@ -168,64 +168,77 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kK1 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kSubQKHeaddim == + KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; - - constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; - static_assert(2 <= k0_loops); - static_assert(2 <= k1_loops); - constexpr bool kPreloadWholeNextIterationK = - Policy::template IsPreloadWholeNextIterationK(); + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); - constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); - constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers(); - constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - static_assert(NumKLdsBuffers >= 2); + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQRegTileDistribution()); + auto q_tile = load_tile(q_dram_window); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); - auto k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); using k_tile_type = decltype(load_tile(k_dram_window)); - auto k_tiles = [&]() { - if constexpr(kPreloadWholeNextIterationK) - return statically_indexed_array{}; - else - return statically_indexed_array{}; - }(); + constexpr index_t NumPrefetchK = 2; - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); + static_assert(k1_loops >= NumPrefetchK, "Check failed!"); - auto q_tile = load_tile(q_dram_window); + // only prefetch two k tiles to save vgprs consumption + statically_indexed_array k_tiles; + + static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }); __builtin_amdgcn_sched_barrier(0); @@ -236,612 +249,341 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_write_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_windows; + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + using k_lds_read_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); - static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_windows[i_buf] = get_slice_tile( - k_lds_window, sequence{}, sequence<(i_buf + 1) * kN0, kK0>{}); + statically_indexed_array k_lds_write_windows; + statically_indexed_array k_lds_read_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_write_windows[i_buf] = + get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? - Policy::template MakeVDramTileDistribution()); // V tile in LDS auto v_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetExclusiveKLdsBytes()), + reinterpret_cast(smem_ptr), Policy::template MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); - using v_tile_type = decltype(load_tile(v_dram_window)); - - statically_indexed_array v_tiles; - using v_lds_window_type = decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array v_lds_windows; + statically_indexed_array v_lds_windows; - static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { v_lds_windows[i_buf] = get_slice_tile( v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); }); - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {0, seqlen_k_start}, + Policy::template MakeVDramTileDistribution()); - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(s_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK) - { - if(num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse, -numeric::infinity()); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. - return o_acc; - } - } - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, + Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + // assuming no random values need be saved, this is try when this pipeline is called from + // xformers, since we have a separate kernel to generated randomm values + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to pass a null_randval_dram and tile window to the BlockDropout operator to + // make it works + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); q_tile = tile_elementwise_in(q_element_func, q_tile); - index_t i_total_loops = 0; + auto seqlen_k_curr = seqlen_k_start; + + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; do { - if constexpr(kPreloadWholeNextIterationK) - { - if(i_total_loops == 0) // executed by fist iteration + // STAGE 1, Gemm_0 ( S = Q@K ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(k_lds_write_windows[number{}], + k_tiles[number{}]); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(i_k1 < k1_loops - NumPrefetchK) { - if(num_total_loop > 1) // there are multiple iterations - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); - - if constexpr(i_k0 == 0) - clear_tile(s_acc); - - block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); - - store_tile( - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - - // prefetch all k_tiles for next iteration - static_for<0, k0_loops, 1>{}([&](auto i_k0) { - k_tiles[number{}] = load_tile(k_dram_window); - - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - }); - - move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); - - block_sync_lds(); - // execute last unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - } - else // there is only single iteration - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); - - if constexpr(i_k0 == 0) - clear_tile(s_acc); - - block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); - - store_tile( - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - - // move_tile_window(k_dram_window, {0, -k0_loops * kK0}); - } + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); } - else // executed by intermediate and last iteration + else { - if(i_total_loops < num_total_loop - 1) // intermediate iteration - { - store_tile(k_lds_windows[I0], - tile_elementwise_in(k_element_func, k_tiles[I0])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - clear_tile(s_acc); - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), - k_lds_windows[I0]); - - store_tile(k_lds_windows[I1], - tile_elementwise_in(k_element_func, k_tiles[I1])); - - move_tile_window(k_dram_window, {kN0, 0}); - - // prefetch first k_tile for next iteration - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - - k_tiles[I1] = load_tile(k_dram_window); - if constexpr(1 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, kK0>{}, sequence{}), - k_lds_windows[I1]); - - // during the gemm-loop, also prefetch other k_tiles for next iteration - static_for<2, k0_loops, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], - k_tiles[number{}]); - - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); - - move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); - } - else // last iteration - { - store_tile(k_lds_windows[I0], - tile_elementwise_in(k_element_func, k_tiles[I0])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - clear_tile(s_acc); - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), - k_lds_windows[I0]); - - static_for<1, k0_loops, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); - }; + // load v_tiles used in current iteration + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); }; - } - else // only preload one unroll of K for next iteration - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); - if constexpr(i_k0 == 0) - clear_tile(s_acc); - if constexpr(i_k0 < k0_loops - 1) - k_tiles[I0] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); - - block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); - - store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + __builtin_amdgcn_sched_barrier(0x00000001); block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - }; + + // execute current unroll of gemm_0 + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + + auto tmp_tile = cast_tile(sacc_tile); + + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); __builtin_amdgcn_sched_barrier(0); const auto bias_tile = load_tile(bias_dram_window); // load bias tile - static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif + [&](auto& x, const auto y) { + x += type_convert(bias_element_func(y)); }, - s_acc, + pcomp_tile, bias_tile); } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); + pcomp_tile(i_j_idx) *= scale_s; + position_encoding.update(pcomp_tile(i_j_idx), row, col); }); }); } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); - } - } - - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - } -#else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); -#endif - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - }(); -#else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); -#endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? - o_acc(i_j_idx) *= tmp; - }); - }); - - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeK(); - dropout.template Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); - } - - __builtin_amdgcn_sched_barrier(0x7f); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - - store_tile( - v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch - } - - __builtin_amdgcn_sched_barrier(0); - - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); - - if constexpr(!kPreloadWholeNextIterationK) - { - if(i_total_loops < num_total_loop - 1) - { - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - }; - - __builtin_amdgcn_sched_barrier(0); - } - - // STAGE 3, KV gemm - if constexpr(k1_loops > 1) - { - if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2 - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - v_tiles[I0] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[I0])); - } - - move_tile_window(v_dram_window, {0, kK1}); - }); - } - else // NumVLdsBuffers == 3 or 2 - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < k1_loops - NumPrefetchV) - v_tiles[number{}] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile( - v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); - } - - if constexpr(i_k1 < k1_loops - NumPrefetchV) - move_tile_window(v_dram_window, {0, kK1}); + set_tile_if(pcomp_tile, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); }); } } - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); + __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) + using v_shuffled_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution())); + + v_shuffled_tile_type v_shuffled_tile; + + shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) { - __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); }; - } while(++i_total_loops < num_total_loop); + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for{}([&](auto i_k1) { + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + seqlen_k_curr += kN0; + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + // k1_loops >= 2 required + shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); + + store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < NumPrefetchK) + { + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 2) + { + __builtin_amdgcn_sched_barrier(0x00000001); + + shuffle_tile(v_shuffled_tile, v_tiles[number{}]); + store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], + v_shuffled_tile); + + __builtin_amdgcn_sched_barrier(0x00000001); + }; + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); // store lse if constexpr(kStoreLSE) @@ -851,19 +593,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } -#else - lse(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); }); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); @@ -874,17 +604,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; }); }); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 3f015a1c1a..e5a45afeea 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -4,17 +4,18 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp" namespace ck_tile { struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy - : BlockFmhaPipelineQXKSVSCustomPolicy { - static constexpr index_t NumPrefetchV = 2; + static constexpr bool QLoadOnce = true; // needed by the kernel + static constexpr bool AsyncCopy = false; // needed by the kernel template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsPreloadWholeNextIterationK() @@ -23,30 +24,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy }; template - CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers() + CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers() { - return 2; + return 4; } - template - CK_TILE_DEVICE static constexpr auto GetNumPrefetchV() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t kN0 = BlockFmhaShape::kN0; - constexpr index_t kK1 = BlockFmhaShape::kK1; - - constexpr index_t k1_loops = kN0 / kK1; - - return min(NumPrefetchV, k1_loops); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers() - { - return 2; - }; - template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -57,49 +39,268 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeCBlockTile() + .get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() { - using KDataType = remove_cvref_t; - return 8 / sizeof(KDataType); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + + constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (ElemPerThread / kMinVecLoad); + + return kVecLoad; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKVector = GetAlignmentK(); + + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + return kKPerBlock * kNPerBlock; + } + else + { + static_assert(kKVector % kKPack == 0); + + return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + }; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + + return N0 * (N1 * kKPerBlock + kKPack); + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + return max(GetKSingleSmemElementSpaceSize(), + GetVSingleSmemElementSpaceSize()); + }; + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { - constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - static_assert(kKVector % kKPack == 0); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); + using KDataType = remove_cvref_t; - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr index_t DataTypeSize = sizeof(KDataType); - return k_lds_block_desc; + // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + constexpr auto k_lds_block_desc_k0_nldslayer_n_k1 = transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{})); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_k0_nldslayer_n_k1, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<1, 3>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr index_t KSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + }; } template @@ -108,8 +309,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); @@ -136,44 +337,45 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() { - using VDataType = remove_cvref_t; + constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t Banks = get_n_lds_banks(); - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t VSingleSmemElementSpaceSize = - (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = ElemPerThread / N1; + + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); + + static_assert(kKPack >= K2, "Check failed!"); + + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - number{}, - number{}, + make_tuple(number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, number<1>{}), number{}, number<1>{}); constexpr auto v_lds_block_desc = transform_tensor_descriptor( v_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple( - number{}, number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return v_lds_block_desc; @@ -182,70 +384,55 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() { - using VLayout = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - if constexpr(std::is_same_v) - { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; // P + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - constexpr index_t K3 = ElemPerThread / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - } - else - { - constexpr index_t K1 = GetAlignmentV(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - static_assert(N0 != 0); + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } + static_assert(ElemPerThread % N1 == 0); + + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + static_assert(ElemPerThread % N1 == 0); + + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); } template @@ -257,113 +444,163 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy typename Problem::SaccDataType, Problem::kNumGemm0Warps * get_warp_size(), TileGemmShape, + Problem::BlockFmhaShape::kK1, + Problem::BlockFmhaShape::kQKHeaddim>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(get_warp_size() == 64 && - std::is_same_v && - std::is_same_v && - std::is_same_v) + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v)&&std:: + is_same_v) { - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32); + constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); + +#ifdef __gfx950__ + static_assert((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16), + "Not supported WarpGemm sizes!"); +#else + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); +#endif - // TODO: hard coded here. Otherwise, it produces incorrect results - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< - swizzle_factor>{}; - } - else - { - constexpr bool SwizzleA = - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32; return WarpGemmDispatcher{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true, // TransposeC - SwizzleA>{}; + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); } }(); + using WarpGemm = remove_cvref_t; + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy; + WarpGemm>; if constexpr(1 < Problem::kNumGemm0Warps) - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2PrefetchK{}; else return BlockGemmARegBSmemCRegOneWarpV1{}; } - // leave some exclusive space so that the second v_lds buffer will nenver overlap with the first - // k_lds bufffer template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes() + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - constexpr index_t single_k_lds_buffer_size = - GetSmemSizeK() / GetNumKLdsBuffers(); - constexpr index_t single_v_lds_buffer_size = - GetSmemSizeV() / GetNumVLdsBuffers(); + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; - if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size) - return 0; + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v)&&std:: + is_same_v) + { + constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}); + +#ifdef __gfx950__ + static_assert((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16), + "Not supported WarpGemm sizes!"); +#else + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); +#endif + + if constexpr((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16)) + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>{}; + else + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); + } + }(); + + using WarpGemm = remove_cvref_t; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + if constexpr(1 < Problem::kNumGemm1Warps) + return BlockGemmARegBSmemCRegV2PrefetchN{}; else - return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); - }; - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; - constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); - constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); - - constexpr index_t last_v_lds_buffer_offset = - MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * - ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType); - - constexpr index_t first_k_lds_buffer_size = - MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * - sizeof(typename Problem::KDataType); - - return GetExclusiveKLdsBytes() + last_v_lds_buffer_offset < - first_k_lds_buffer_size; - }; - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() - { - return MakeKLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::KDataType); + return BlockGemmARegBSmemCRegOneWarpV1{}; } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { - return MakeVLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::VDataType); - } + constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); + + return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * + max(sizeof(typename Problem::KDataType), sizeof(typename Problem::VDataType)); + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + { + return 0; + }; template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // assume V can reuse the other shared memory by K except the first - // assume Dropout can reuse the shared memory by V - return GetExclusiveKLdsBytes() + - max(GetSmemSizeK() - GetExclusiveKLdsBytes(), - max(GetSmemSizeV(), GetSmemSizeDropout(0))); + return GetSmemSizeKV() + GetSmemSizeDropout(); } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp new file mode 100644 index 0000000000..3f21c44207 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2PrefetchK +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; +#endif + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + // hot loop: + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + statically_indexed_array b_warp_tensors; + + b_warp_windows(nIter)(I0) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(I0), + {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter}); + b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0)); + + __builtin_amdgcn_sched_barrier(0); + + b_warp_windows(nIter)(I1) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(I1), + {nIter * NPerBlockPerIter, 1 * KPerBlockPerIter}); + b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1)); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // warp GEMM + auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + + static_for<1, KIterPerWarp, 1>{}([&](auto kIter) { + // read B warp tensor from B Block window + if constexpr(kIter < KIterPerWarp - 1) + { + b_warp_windows(nIter)(number{}) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(number{}), + {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter}); + b_warp_tensors[number{}] = + load_tile(b_warp_windows(nIter)(number{})); + }; + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode(); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode(); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp new file mode 100644 index 0000000000..3ad4037926 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2PrefetchN +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + statically_indexed_array, + NIterPerWarp> + b_warp_tensors; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {0 * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter)); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(number{})(kIter) = + load_tile(b_warp_windows(number{})(kIter)); + }); + }; + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); + }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile