From eb598a9d1ed8dab09137c9c7fa6539cd720d4727 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 18 Dec 2025 10:50:30 +0000 Subject: [PATCH] Add qr_ks_vs_whole_k_prefetch_trload pipeline --- include/ck_tile/ops/fmha.hpp | 1 + ..._ks_vs_whole_k_prefetch_default_policy.hpp | 23 +- ...eline_qr_ks_vs_whole_k_prefetch_trload.hpp | 835 ++++++++++++++++++ ...m_areg_bsmem_trload_creg_v2_prefetch_n.hpp | 243 +++++ 4 files changed, 1097 insertions(+), 5 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 5b87a821c9..37267efce3 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -54,6 +54,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 1fcf823abf..d805075980 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp" namespace ck_tile { @@ -235,11 +236,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); - return N0 * (N1 * kKPerBlock + kKPack); + return N0 * (N1 * kKPerBlock + kKPack); + } + else + { + return kNPerBlock * kKPerBlock; + }; }; template @@ -958,7 +966,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy WarpGemm>; if constexpr(1 < Problem::kNumGemm1Warps) - return BlockGemmARegBSmemCRegV2PrefetchN{}; + { + if constexpr(!Problem::kUseTrLoad) + return BlockGemmARegBSmemCRegV2PrefetchN{}; + else + return BlockGemmARegBSmemTrLoadCRegV2PrefetchN{}; + } else return BlockGemmARegBSmemCRegOneWarpV1{}; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp new file mode 100644 index 0000000000..4ebfe19bd7 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -0,0 +1,835 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr bool kUseN0Loop = true; + static constexpr bool kIgnoreFastExp2 = true; + static constexpr bool kIsNaiveHDimLoad = true; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kN0Sub = BlockFmhaShape::kN0Sub; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + + static_assert(Problem::kUseTrLoad == true, "Check failed!"); + + static constexpr bool kUseTrLoad = true; + + static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!"); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& /* unused */, + const AttentionVariantParams& /* unused */, + const BlockIndices& /* unused */, + void* smem_ptr, + DropoutType& dropout) const + { + // xformers path does not require the pipeline to output random values for host + // verification, since a separate kernel is used to generate random values + ignore = randval_dram_block_window_tmp; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr index_t n0_loops = kN0 / kN0Sub; + constexpr index_t k1_loops = kN0 / kK1; + + // usually kN0 is 128, kN0Sub/kK1 is 32/16 + static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); + static_assert(n0_loops >= NumPrefetchV, "Check failed!"); + static_assert(k1_loops >= NumPrefetchV, "Check failed!"); + + constexpr bool kPreloadWholeNextIterationK = + Policy::template IsPreloadWholeNextIterationK(); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramSingleRepMTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + + auto q_dram_tile = load_tile(q_dram_window); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + // only prefetch two k tiles to save vgprs consumption + auto k_tiles = [&]() { + if constexpr(kPreloadWholeNextIterationK) + return statically_indexed_array{}; + else + return statically_indexed_array{}; + }(); + + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(smem_ptr); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_write_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + auto q_lds_read_window = + make_tile_window(q_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQRegTileDistribution(), + partition_index); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + using k_lds_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + }); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kN1>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); + + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, + Policy::template MakeBiasDramTileDistribution()); + + // assuming no random values need be saved, this is true when the pipeline is called from + // xformers, since we have a separate kernel to generated randomm values + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to pass a null_randval_dram and tile window to the BlockDropout operator to + // make it works + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + store_tile(q_lds_write_window, q_dram_tile, partition_index); + + clear_tile(o_acc); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + auto q_tile = load_tile(q_lds_read_window); + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + set_tile(m, -numeric::infinity()); + clear_tile(l); + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + auto seqlen_k_curr = seqlen_k_start; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile + __builtin_amdgcn_s_barrier(); + + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + do + { + // STAGE 1, Gemm_0 ( S = Q@K ) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 + { + if(seqlen_k_curr == seqlen_k_start) // at first iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 < NumPrefetchV) + { + v_tiles[i_n0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + // prefetch all k_tiles for next iteration + static_for<0, n0_loops, 1>{}([&](auto ii_n0) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + else // the iteration is also the last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 < NumPrefetchV) + { + v_tiles[i_n0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + }; + } + else // at intermediate and last iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < NumPrefetchV) + { + v_tiles[i_n0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + else // last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < NumPrefetchV) + { + v_tiles[i_n0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + }; + } + } + else // only preload one unroll of K for next iteration, used when kM0=128 + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[I0]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 < NumPrefetchV) + { + v_tiles[i_n0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + + __builtin_amdgcn_sched_barrier(0x000000001); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + + tile_elementwise_inout( + [&](auto& x, const auto y) { + x += type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + pcomp_tile(i_j_idx) *= scale_s; + position_encoding.update(pcomp_tile(i_j_idx), row, col); + }); + }); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + move_tile_window(bias_dram_window, {0, kN0}); + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(pcomp_tile, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + seqlen_k_curr += kN0; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[I0]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 1) + { + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]), + partition_index); + }; + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp new file mode 100644 index 0000000000..f92216ef28 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + constexpr auto b_warp_dstr_encode = + typename InputTileDistributionTraits::TransposedDstrEncode{}; + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN}, + make_static_tile_distribution(b_warp_dstr_encode)); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0))); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {kIter * KPerBlockPerIter, 0 * NPerBlockPerIter}); + b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter)); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile_transpose(b_warp_windows(number{})(kIter)); + }; + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile