From e54cb5a713036a3d019f66188ca7c62265fdeb63 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 6 Oct 2025 13:02:38 +0000 Subject: [PATCH] intial commit --- .../block/block_attention_bias_enum.hpp | 37 + .../unified_attention/block/block_dropout.hpp | 654 +++++++++ .../unified_attention/block/block_masking.hpp | 642 +++++++++ .../block/block_position_encoding.hpp | 205 +++ .../block/block_rotary_embedding.hpp | 108 ++ .../block/page_block_navigator.hpp | 358 +++++ .../ops/unified_attention/block/variants.hpp | 302 ++++ .../kernel/fmha_fwd_v3_kernel.hpp | 450 ++++++ .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 1258 +++++++++++++++++ ...ck_fmha_fwd_v3_pipeline_default_policy.hpp | 603 ++++++++ .../pipeline/block_fmha_pipeline_enum.hpp | 42 + .../pipeline/block_fmha_pipeline_problem.hpp | 60 + 12 files changed, 4719 insertions(+) create mode 100644 include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_dropout.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_masking.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/variants.hpp create mode 100644 include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp diff --git a/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp new file mode 100644 index 0000000000..e5be21e048 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockAttentionBiasEnum +{ + NO_BIAS = 0, + ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale) + ALIBI = 2, // bias computed with position encoding, applied after scale +}; + +template +struct BlockAttentionBiasEnumToStr; + +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "bias"; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "alibi"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_dropout.hpp b/include/ck_tile/ops/unified_attention/block/block_dropout.hpp new file mode 100644 index 0000000000..8abdd54cd9 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_dropout.hpp @@ -0,0 +1,654 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and +// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random +// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host +// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp). +// +// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of +// random numbers (ph_subsequence). +// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and +// ph_offset). +// This means that subsequences are non-overlapping, reproducible and independent of mask or window. +// +// There are 3 modes (all produce the same results): +// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates +// the entire 32x32 tile (64 * 16 = 32 * 32). +// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4 +// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock > +// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions +// are needed for generating a 32x32 tile. +// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2 +// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp * +// WG::kM one warp can generate two 16x16 tiles. + +namespace detail { +// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values +constexpr index_t philox_per_tile = 64; +} // namespace detail + +struct NullBlockDropout +{ + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + +struct BlockDropout +{ + CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_, + bool is_store_randval_) + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_), + is_store_randval(is_store_randval_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t NIterPerWarp = 1; + + // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution, + // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); + + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + + const auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) + { + // Generate the whole 32x32 tile at once (each tile consists of random numbers taken + // from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + else + { + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } + } + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // Transpose randval using LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + const auto randval = load_tile(randval_lds_read_window); + block_sync_lds(); + return randval; + }; + + if(is_store_randval) + { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + const auto randval = generate_randval(i_m0, i_n0); + // save to Global + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + }); + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + }); + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + } + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = + tile_distributed_index()>{}; + constexpr auto p_idx1 = + tile_distributed_index(), + idx1.impl_.template at<2>()>{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + }); + }); + } + + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be +// replaced with NullBlockDropout. This requires changes in xformers and other libs. +template +struct BlockDropoutBwd; + +template +struct BlockDropoutBwd +{ + static constexpr bool IsDropout = false; + static constexpr bool IsStoreRandval = IsStoreRandval_; + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + +template +struct BlockDropoutBwd +{ + static constexpr bool IsDropout = true; + static constexpr bool IsStoreRandval = IsStoreRandval_; + + CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_) + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; + static_assert( + std::is_same_v, + typename WG::CWarpDstrEncoding>); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) + { + // Generate the whole 32x32 tile at once (each tile consists of random numbers + // taken from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + else + { + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } + } + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + return randval_dist_generated; + }; + + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities, negative sign is used to + // distinguish such values ​​later in bwd pipeline. + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + constexpr auto p_idx0 = + tile_distributed_index(), + idx0.impl_.template at<1>(), + idx0.impl_.template at<2>()>{}; + constexpr auto p_idx1 = tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] + : -p_compute[p_idx]; + }); + }); + // save to Global + if constexpr(IsStoreRandval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {kMPerStep, 0}); + } + }); + if constexpr(IsStoreRandval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); + } + }); + if constexpr(IsStoreRandval) + { + move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); + } + } + + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp new file mode 100644 index 0000000000..2c45945fac --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -0,0 +1,642 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +enum struct GenericAttentionMaskEnum +{ + NO_MASK = 0, + + // below enum could be causal, or sliding window + MASK_FROM_TOP_LEFT = 1, + MASK_FROM_BOTTOM_RIGHT = 2, + + // this enum maybe not used by xformer/FA, since it's hard to + // specify left/right window for varlen case. put it here for + // debug purpose + MASK_GENERIC, +}; + +// clang-format off +/* generic Attention Mask Coordinate + use x(horizontal axis), y(vertical axis) to describe mask. + top-left corner is origin + + x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask) + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + l=7,-1/r=0(tl) l=7,-1/r=0(br) + + x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2 + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 + * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1 + l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl) + l=4/r=0(br) l=4/r=2(br) l=4/r=4(br) + + x=4/y=-1 x=6/y=-1 x=8/y=-1 + * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 + * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 + * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1 + * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1 + + x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r) + * * * * * * * * 1 * * * * * * * + * * * * * * * * 1 1 * * 1 * * * + * * * * * * * * 1 1 1 * 1 1 * * + 1 * * * * * * * 1 1 1 1 1 1 1 * + 1 1 * * * * * * 1 1 1 1 1 1 1 1 + + Validations: + x + y > 1 (x + y >= 2) + + Note: + y = seq_q, x = 1 -> top-left + y = seq_q, x = seq_k - seq_q + 1 -> bottom-right + y < seq_q, x < seq_k -> local-attn + y = seq_q, x = seq_k -> no mask + +*/ +namespace impl { + template struct MaskName; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mc"; }; + template<> struct MaskName { static constexpr const char * name = "mg"; }; +} +// clang-format on + +template +struct GenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, + // else only upper-right could have mask + + static constexpr const char* name = impl::MaskName::name; + + CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) + : GenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = min(i_y + x, x_total); + + if constexpr(IsLocal) + { + return i_x < x_start || i_x >= x_end; + } + else + { + return i_x >= x_end || i_y >= y_total; + } + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const + { + if constexpr(!IsMasking) + { + // TODO: no need to check begin + return (i_tile_left + TileWidth) > x_total; + } + else + { + if constexpr(IsLocal) + { + // check top-right corner > x or left-borrom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = + i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +// clang-format off +namespace impl { + template struct SimplifiedMaskName; + template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version only have 2 variation: masking and non-masking +// This is more friendly to codegen (e.g. need generate less kernel) +// ... with the trade-off that may have more instruction in causal mode +template +struct SimplifiedGenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + template + CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t split_start = x_per_split * i_split; + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); + + return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), + ck_tile::min(origin_end, split_end)); + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + return i_x >= x_total; + } + else + { + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = min(i_y + x, x_total); // need min in case x is padded + + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + // TODO: no need to check begin + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + + bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad + bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad + // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +// clang-format off +namespace impl { + template struct SimplifiedRatioMaskName; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version is used for cases that the step length of y-direction changes greater than one. It +// means that the mask is not a regular triangular matrix. + +// clang-format off +/* y_ratio is used to describe the step length of y-direction changes + in certain performance optimization scenarios like merging seqlen + and qk_head_ratio, for example: + + x=1/y=6/y_ratio=2(top-left) + 1 * * * * * * * + 1 * * * * * * * + 1 1 * * * * * * + 1 1 * * * * * * + 1 1 1 * * * * * + 1 1 1 * * * * * + +*/ +// clang-format on +template +struct SimplifiedRatioAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedRatioMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{}) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask( + index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_) + : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast(y_ratio_mdiv_.get()), + /*x_=*/x_, + /*y_total_=*/y_total_, + /*x_total_=*/x_total_, + /*y_real_=*/y_real_, + /*y_ratio_=*/static_cast(y_ratio_mdiv_.get()), + /*y_ratio_mdiv_=*/y_ratio_mdiv_) + + { + } + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask(index_t y_, + index_t x_, + index_t y_total_, + index_t x_total_, + index_t y_real_, + index_t y_ratio_, + mdiv y_ratio_mdiv_) + : y(y_), + x(x_), + y_total(y_total_), + x_total(x_total_), + y_real(y_real_), + y_ratio(y_ratio_), + y_ratio_mdiv(y_ratio_mdiv_) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = -y_real + + static_cast(y_ratio_mdiv.div(static_cast(i_y))) + + 1; + + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + uint32_t y_offset = i_y + YTile - 1; + index_t tmp = min(static_cast(y_ratio_mdiv.div(y_offset)) + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max((-x + i_x + 1) * y_ratio, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + index_t x_tmp = static_cast(y_ratio_mdiv.div(static_cast(i_y))); + index_t x_start = -y_real + x_tmp + 1; + index_t x_end = min(x_tmp + x, + x_total); // need min in case x is padded + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + uint32_t y_tmp = static_cast(i_y); + bool top_right_edge = i_x_end > min(static_cast(y_ratio_mdiv.div(y_tmp)) + x, + x_total); // consider right pad + bool bottom_left_edge = + i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; + // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y + index_t y_real; + index_t y_ratio; + mdiv y_ratio_mdiv; +}; + +// TODO: prefer use this function in host code +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + // TODO: below should all use sgpr arithmetic + index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1; + index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1; + + left_size = left_size < 0 ? left_size_tmp : left_size; + right_size = right_size < 0 ? right_size_tmp : right_size; + + index_t x_tmp = is_top_left ? 0 : x_total - y_total; + index_t y_tmp = is_top_left ? 0 : y_total - x_total; + + index_t x = 1 + right_size + x_tmp; + index_t y = 1 + left_size + y_tmp; + + return ck_tile::make_tuple(y, x, y_total, x_total); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp new file mode 100644 index 0000000000..703ec0967a --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include +#include + +namespace ck_tile { + +enum struct PositionEncodingEnum +{ + NO = 0, + ALIBI = 1, +}; + +/* +VERTICAL: + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + +TOP_LEFT(but negative): + [0] 1 2 3 4 5 + 1 [0] 1 2 3 4 + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + +FROM_BOTTOM_RIGHT(but negative): + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + 4 3 2 1 [0] 1 + 5 4 3 2 1 [0] +*/ + +enum struct AlibiMode +{ + VERTICAL = 0, + FROM_TOP_LEFT = 1, // keep sync with mask enum + FROM_BOTTOM_RIGHT = 2, +}; + +template +struct Alibi +{ + static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32, + "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t"); + + // RowMajor here means if pixel within the same thread are along the row, or col + // this may impact the performance of update(), while the result are the same. + // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false + CK_TILE_HOST_DEVICE Alibi(DataType slope_, + index_t y_total_, + index_t x_total_, + AlibiMode mode_ = AlibiMode::VERTICAL) + { + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; + + shift_left_up = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + }(); + shift_right_down = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + }(); + mode = mode_; + } + + CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); } + + CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) + { + if constexpr(LogMaxSadOprndSize <= 16) + { + return sad_u16( + static_cast(x), static_cast(y), static_cast(acc)); + } + + return sad_u32(x, y, acc); + } + + CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) + { + if constexpr(RowMajor) + { + // at least 3 instructions per row + index_t current_zero_point = + mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down; + + // for every threads, most of the pixels are along the row, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(col_idx + shift_left_up), + 0)); + pixel += slope * position; + } + else + { + // at least 3 instructions per col; + index_t current_zero_point = mode == AlibiMode::VERTICAL + ? row_idx + col_idx + shift_right_down + : col_idx + shift_right_down; + + // for every threads, most of the pixels are along the col, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(row_idx + shift_left_up), + 0)); + pixel += slope * position; + } + } + + DataType slope; // float? + index_t shift_left_up; // always possitive + index_t shift_right_down; // always possitive + AlibiMode mode; +}; + +template +struct EmptyPositionEncoding +{ + CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) + { + } +}; + +// +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +template +CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, + index_t window_left_size, + index_t window_right_size, + index_t y_total, + index_t x_total, + GenericAttentionMaskEnum mask_enum) +{ + // assume mask_enum will never be NO_MASK, since if we do not have mask, it's + // totally OK to use constexpr + bool is_causal = window_left_size < 0 && window_right_size == 0; + AlibiMode alibi_mode = + is_causal ? AlibiMode::VERTICAL + : static_cast(mask_enum) /*either top-left or bottom-right*/; + return Alibi{slope, y_total, x_total, alibi_mode}; +} + +// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 +// Do we need a device version? +template +CK_TILE_HOST std::vector get_alibi_slopes(ck_tile::index_t nheads) +{ + auto get_slopes_power_of_2 = [](ck_tile::index_t n) { + float start = std::powf( + static_cast(2), + -std::powf(static_cast(2), -static_cast((integer_log2_floor(n) - 3)))); + + std::vector rtn; + for(auto i = 0; i < n; i++) + { + rtn.push_back(static_cast(start * std::powf(start, i))); + } + return rtn; + }; + if(is_power_of_two_integer(nheads)) + { + // power of 2 calculation + return get_slopes_power_of_2(nheads); + } + else + { + ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads); + auto v0 = get_slopes_power_of_2(closest_power_of_2); + auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2); + auto v1_sliced = [&](auto vec, ck_tile::index_t rem) { + std::vector sliced; + for(ck_tile::index_t i = 0; i < static_cast(vec.size()); i++) + { + if(i % 2 == 0) + sliced.push_back(vec[i]); + } + std::vector sliced_2(sliced.begin(), sliced.begin() + rem); + return sliced_2; + }(v1, nheads - closest_power_of_2); + v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end()); + return v0; + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp b/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp new file mode 100644 index 0000000000..5173279299 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class RotaryEmbeddingEnum +{ + NONE = 0, + INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc + HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc +}; + +template +struct RotaryEmbeddingEnumToStr; + +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "inter"; +}; +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "half"; +}; + +template +struct BlockRotaryEmbedding +{ + template + CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile, + OtherDramBlockWindow other_window, + RotaryCosDramBlockWindow rotary_cos_window, + RotarySinDramBlockWindow rotary_sin_window, + index_t rotary_dim, + index_t thread_end) + { + using DataType = typename remove_cvref_t::DataType; + + if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) + { + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + if(thread_end <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(tile.thread_buf_[idx]); + const auto right = type_convert(tile.thread_buf_[idx + 1]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + tile.thread_buf_[idx] = type_convert(left * cos - right * sin); + tile.thread_buf_[idx + 1] = type_convert(right * cos + left * sin); + }); + } + } + else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + if(thread_end <= rotary_dim) + { + const bool is_left = (thread_end <= (rotary_dim / 2)); + + move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto other_tile = load_tile(other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(tile.thread_buf_[idx]); + const auto other = type_convert(other_tile.thread_buf_[idx]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx]); + + tile.thread_buf_[idx] = + type_convert(curr * cos + other * (is_left ? -sin : sin)); + }); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp b/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp new file mode 100644 index 0000000000..f1e6101d1d --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp @@ -0,0 +1,358 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" + +namespace ck_tile { + +// assume that we have only 1 page-block/tensor view +template +struct TrivialPageBlockNavigator +{ + using DataType = typename TensorView::DataType; + using WindowOrigin = multi_index<2>; + + CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_) + : tensor_view(tensor_view_) + { + } + + template + CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin) const + { + return make_tuple(/*block_index=*/0, + ck_tile::make_tile_window(tensor_view, window_lengths, window_origin)); + } + + template + CK_TILE_HOST_DEVICE constexpr auto + make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin, + const TileDistribution& tile_distribution) const + { + return make_tuple( + /*block_index=*/0, + ck_tile::make_tile_window( + tensor_view, window_lengths, window_origin, tile_distribution)); + } + + template + CK_TILE_HOST_DEVICE static index_t + move_tile_window(index_t /*block_index*/, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) + { + ck_tile::move_tile_window(tile_window, step); + + return /*block_index=*/0; + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t /*block_index*/, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step, + index_t /*id*/) const + { + + ck_tile::move_tile_window(tile_window, step); + return 0; + } + + template + CK_TILE_HOST_DEVICE index_t + prefetch_table_id(index_t /*block_index*/, + TileWindow /*tile_window*/, + const typename remove_cvref_t::BottomTensorIndex& /*step*/) const + { + return -1; + } + + CK_TILE_HOST_DEVICE static constexpr WindowOrigin + to_local_window_origin(const WindowOrigin& global_window_origin) + { + return global_window_origin; + } + + CK_TILE_HOST_DEVICE static constexpr WindowOrigin + to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin) + { + return local_window_origin; + } + + private: + TensorView tensor_view; +}; + +// default page-block navigator, assume that tensor view size is same as page-block size or smaller +// if tile window on last page-block +template +struct PageBlockNavigator +{ + using DataType = DataType_; + static_assert(std::is_same_v); + static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window"); + using WindowOrigin = multi_index<2>; + + CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t* physical_blocks_, + long_index_t block_stride_, + long_index_t fixed_offset_, + const int32_t* physical_block_indices_, + index_t num_blocks_, + index_t page_block_size_, + const TensorView& complete_view_, + const TensorView& last_view_) + : physical_blocks(reinterpret_cast(physical_blocks_)), + block_stride(block_stride_), + fixed_offset(fixed_offset_), + physical_block_indices(physical_block_indices_), + num_blocks(num_blocks_), + page_block_size(page_block_size_), + complete_view(complete_view_), + last_view(last_view_) + { + } + + template + CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin) const + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, + window_lengths, + local_window_origin); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + + template + CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin, + const TileDistribution& tile_distribution) const + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, + window_lengths, + local_window_origin, + tile_distribution); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) const + { + + ck_tile::move_tile_window(tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, tile_window.get_window_origin()); + const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); + + const index_t new_block_index = get_block_index(global_window_origin); + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + + return new_block_index; + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step, + index_t id) const + { + ck_tile::move_tile_window(tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, tile_window.get_window_origin()); + const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); + + const index_t new_block_index = get_block_index(global_window_origin); + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + if(id >= 0) + tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride + + fixed_offset); + else + tile_window.set_bottom_tensor_view_data_ptr(nullptr); + + return new_block_index; + } + + template + CK_TILE_HOST_DEVICE index_t + prefetch_table_id(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) const + { + auto local_tile_window = tile_window; // not affect origin window + ck_tile::move_tile_window(local_tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, local_tile_window.get_window_origin()); + const index_t new_block_index = get_block_index(global_window_origin); + + if(new_block_index < num_blocks) + { + return physical_block_indices[new_block_index]; + } + else + { + return -1; + } + } + + CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const + { + return block_index == num_blocks - 1; + } + + template + CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, + const TileWindow& tile_window) const + { + const index_t origin = tile_window.get_window_origin().at(number{}); + const index_t length = tile_window.get_window_lengths().at(number{}); + return (block_index < num_blocks - 1) && (page_block_size < origin + length); + } + + template + CK_TILE_HOST_DEVICE void + move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const + { + const multi_index<2> step = [&]() { + const index_t origin_diff = (block_index - new_block_index) * page_block_size; + if constexpr(VirtualDim == 0) + { + return make_multi_index(origin_diff, 0); + } + else + { + return make_multi_index(0, origin_diff); + } + }(); + + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(tile_window.get_window_origin() + step); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + } + + CK_TILE_HOST_DEVICE WindowOrigin + to_local_window_origin(const WindowOrigin& global_window_origin) const + { + if constexpr(VirtualDim == 0) + { + const index_t length = global_window_origin.at(number<0>{}); + const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); + return make_multi_index(length - page_block_size * num_complete_blocks, + global_window_origin.at(number<1>{})); + } + else + { + const index_t length = global_window_origin.at(number<1>{}); + const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); + return make_multi_index(global_window_origin.at(number<0>{}), + length - page_block_size * num_complete_blocks); + } + } + + CK_TILE_HOST_DEVICE WindowOrigin + to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const + { + if constexpr(VirtualDim == 0) + { + return make_multi_index(block_index * page_block_size + + local_window_origin.at(number<0>{}), + local_window_origin.at(number<1>{})); + } + else + { + return make_multi_index(local_window_origin.at(number<0>{}), + block_index * page_block_size + + local_window_origin.at(number<1>{})); + } + } + + private: + CK_TILE_HOST_DEVICE + DataType* get_block_ptr(index_t block_index) const + { + if(block_index < num_blocks) + { + return physical_blocks + physical_block_indices[block_index] * block_stride + + fixed_offset; + } + else + { + return nullptr; + } + } + + CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const + { + return integer_divide_floor(global_window_origin.at(number{}), page_block_size); + } + + DataType* physical_blocks; + long_index_t block_stride; + long_index_t fixed_offset; + + const int32_t* physical_block_indices; + index_t num_blocks; + index_t page_block_size; + + TensorView complete_view; + TensorView last_view; +}; + +template +CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view) +{ + return TrivialPageBlockNavigator(tensor_view); +} + +template +CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t* physical_blocks, + long_index_t block_stride, + long_index_t fixed_offset, + const int32_t* physical_block_indices, + index_t num_blocks, + index_t page_block_size, + const TensorView& complete_view, + const TensorView& last_view) +{ + return PageBlockNavigator(physical_blocks, + block_stride, + fixed_offset, + physical_block_indices, + num_blocks, + page_block_size, + complete_view, + last_view); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/variants.hpp b/include/ck_tile/ops/unified_attention/block/variants.hpp new file mode 100644 index 0000000000..d8b0cdbb86 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/variants.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include +#include + +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0 +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1 + +#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH +#endif + +#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM +#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 +#endif + +namespace ck_tile { +namespace internal { +__device__ inline float +exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp) +{ +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + /// NOTICE: Make sure softmax_scale is stored in SGPR + float result, numerator, denominator; + asm volatile( + "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" + "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" + "v_rcp_f32_e32 %[denominator], %[denominator]\n" + "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" + "v_mul_f32_e32 %[result], %[numerator], %[denominator]" + : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result) + : [softmax_scale] "s"(softmax_scale), + [logits] "v"(logits), + [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp)); + return result; +#else + return softmax_scale * logits * rcp(1.f + abs(logits * logits_soft_cap_rcp)); +#endif +} +} // namespace internal + +template +struct StandardAttentionParams +{ + __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_) + : impl_mask(impl_mask_), sm_scale(sm_scale_) + { + } + + const ImplMask& impl_mask; + float sm_scale; +}; + +template +struct LogitsSoftCapParams +{ + __device__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap); + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __host__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_, + float sm_scale_, + float logits_soft_cap_, + float logits_soft_cap_rcp_) + : impl_mask(impl_mask_), + sm_scale(sm_scale_), + logits_soft_cap(logits_soft_cap_), + logits_soft_cap_rcp(logits_soft_cap_rcp_) + { + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + const ImplMask& impl_mask; + float sm_scale; + float logits_soft_cap; + float logits_soft_cap_rcp; +}; + +struct StandardAttention +{ + __device__ __host__ StandardAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +template +struct LogitsSoftCap +{ + __device__ __host__ LogitsSoftCap() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(UseExp2) + { + return q; + } + else + { + return type_convert(q) * params.sm_scale; + } + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +constexpr uint32_t CUSTOM_MASK = 1U; +constexpr uint32_t SLIDING_WINDOW = 2U; +constexpr uint32_t LOGITS_SOFT_CAP = 4U; +constexpr uint32_t ALIBI = 8U; + +template +struct ComposedAttention +{ + static constexpr bool use_exp2 = UseExp2; + + static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0; + + __device__ __host__ ComposedAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(use_logits_soft_cap && UseExp2) + { + return q; + } + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(use_logits_soft_cap) + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp new file mode 100644 index 0000000000..9d164b639e --- /dev/null +++ b/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp @@ -0,0 +1,450 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" + +#include +#include + +namespace ck_tile { + +template +struct FmhaFwdV3Kernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + // The attention is default causal + struct UnifiedAttentionCommonKargs + { + const void* q_ptr; + const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + void* o_ptr; + + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t num_queries_per_kv; + // scales + float scale_s; + float scale; + float scale_k; + float scale_v; + float scale_out; + + ck_tile::index_t total_num_q_blocks; + ck_tile::index_t query_stride_0; + ck_tile::index_t query_stride_1; + ck_tile::index_t stride_k_cache_0; + ck_tile::index_t stride_k_cache_1; + ck_tile::index_t stride_k_cache_2; + ck_tile::index_t stride_k_cache_3; + ck_tile::index_t stride_v_cache_0; + ck_tile::index_t stride_v_cache_1; + ck_tile::index_t stride_v_cache_2; + ck_tile::index_t stride_v_cache_3; + ck_tile::index_t output_stride_0; + ck_tile::index_t output_stride_1; + ck_tile::index_t HEAD_SIZE_PADDED; + }; + + + struct UnifiedAttentionVarlenKargs + { + const int32_t* block_tables_ptr; + const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* query_start_len_ptr; // [num_seqs+1] + + ck_tile::index_t num_seqs; // number of batches for q + ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent???? + ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent???? + }; + + struct Kargs { + UnifiedAttentionCommonKargs unifiedAttentionCommonKargs; + UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs; + }; + + // using Kargs = FmhaFwdGroupModeKargs; + + CK_TILE_HOST static constexpr Kargs MakeKargs( + const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* o_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t num_queries_per_kv, + float scale_s, + float scale, + float scale_k, + float scale_v, + float scale_out, + ck_tile::index_t total_num_q_blocks, + ck_tile::index_t query_stride_0, + ck_tile::index_t query_stride_1, + ck_tile::index_t stride_k_cache_0, + ck_tile::index_t stride_k_cache_1, + ck_tile::index_t stride_k_cache_2, + ck_tile::index_t stride_k_cache_3, + ck_tile::index_t stride_v_cache_0, + ck_tile::index_t stride_v_cache_1, + ck_tile::index_t stride_v_cache_2, + ck_tile::index_t stride_v_cache_3, + ck_tile::index_t output_stride_0, + ck_tile::index_t output_stride_1, + const int32_t* block_tables_ptr, + const int32_t* seq_lens_ptr, + const int32_t* query_start_len_ptr, + ck_tile::index_t num_seqs, + ck_tile::index_t BLOCK_SIZE, + ck_tile::index_t BLOCK_Q + ) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + hdim_q, + hdim_v, + num_head_q, + num_queries_per_kv, + static_cast(scale_s * ck_tile::log2e_v<>), + scale, + scale_k, + scale_v, + scale_out, + total_num_q_blocks, + query_stride_0, + query_stride_1, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + output_stride_0, + output_stride_1}, + { + block_tables_ptr, + seq_lens_ptr, + query_start_len_ptr, + num_seqs, + BLOCK_SIZE, + BLOCK_Q, + }}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, + ck_tile::index_t total_num_q_blocks) + { + return dim3(num_kv_heads * total_num_q_blocks, 0, 0); + } + + // CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads, + // ck_tile::index_t total_num_q_blocks) + // { + // // TODO: fix 3D grid + // return dim2(num_kv_heads, total_num_q_blocks); + // } + + // Binary search to find the sequence index for a given target index + CK_TILE_DEVICE static constexpr ck_tile::index_t + find_seq_idx(const int32_t* query_start_len_ptr, + ck_tile::index_t target_idx, + ck_tile::index_t num_seqs, + ck_tile::index_t BLOCK_Q, + bool use_q_block_mode) + { + ck_tile::index_t left = 0; + ck_tile::index_t right = num_seqs; + + while (left < right) + { + ck_tile::index_t mid = (left + right) / 2; + ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t mid_val = use_q_block_mode ? (val / BLOCK_Q + mid) : val; + + if (mid_val <= target_idx) + { + left = mid + 1; + } + else + { + right = mid; + } + } + + return left - 1; + } + + CK_TILE_DEVICE static constexpr auto + RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) + { + using namespace ck_tile; + + constexpr index_t NUM_XCDS = 8; + const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks * + (kargs.unifiedAttentionCommonKargs.num_head_q); + + // Number of pids per XCD in the new arrangement + const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; + + // When GRID_MN cannot divide NUM_XCDS, some xcds will have + // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds + index_t tall_xcds = GRID_MN % NUM_XCDS; + tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; + + // Compute current XCD and local pid within the XCD + const index_t xcd = pid % NUM_XCDS; + const index_t local_pid = pid / NUM_XCDS; + + // Calculate new pid based on the new grouping + index_t remapped_pid = 0; // Initialize to avoid constexpr error + if(xcd < tall_xcds) + { + remapped_pid = xcd * pids_per_xcd + local_pid; + } + else + { + remapped_pid = tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid; + } + + return remapped_pid; + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) + { + using namespace ck_tile; + + ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks; + // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, + // FmhaPipeline::kN1); + + const index_t i_tile_m = pid % total_num_q_blocks; // Query block index + const index_t i_tile_n = pid / total_num_q_blocks; // Head index + + return ck_tile::make_tuple(i_tile_m, i_tile_n); + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using namespace ck_tile; + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + ck_tile::index_t pid = blockIdx.x; + + pid = RemapTileIndices(pid, kargs); + + // divide problem + const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); + + const index_t seq_idx = find_seq_idx( + kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true + ); // which batch + + const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + + const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + + const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]); + + const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + + // TODO check if we get the block size info from pipeline + if (q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q >= cur_batch_query_len) { + return; + } + + const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q; + + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + + static_cast(kv_head_idx) * kargs.unifiedAttentionCommonKargs.num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1 + + static_cast(cur_batch_in_all_start_index) * kargs.unifiedAttentionCommonKargs.query_stride_0; + // const KDataType* k_ptr = + // reinterpret_cast(kargs.k_ptr) + + // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + // batch_offset_k; + // const VDataType* v_ptr = + // reinterpret_cast(kargs.v_ptr) + + // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + // batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {0, i_n1}); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto o_acc_tile = [&]() { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + kargs.scale_s, + smem_ptr); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp new file mode 100644 index 0000000000..b151b61028 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -0,0 +1,1258 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +#define ENABLE_ASM_MARKER 1 +#if ENABLE_ASM_MARKER +#define ASM_MARKER(marker) \ + __builtin_amdgcn_sched_barrier(0); \ + asm volatile("; [POYENC] " #marker); \ + __builtin_amdgcn_sched_barrier(0); +#else +#define ASM_MARKER(marker) +#endif + +#define ADD_SBARRIER_FOR_PHASE0 1 +#if !defined(CK_TILE_DISABLE_PACKED_FP32) +#define CK_TILE_DISABLE_PACKED_FP32 0 +#endif + +#define WARP_ID 0 +#define LANE_ID 0 + +#define ENABLE_DEBUG_STMTS 1 +#if ENABLE_DEBUG_STMTS +#define DEBUG_STMTS \ + if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) +#else +#define DEBUG_STMTS if constexpr(false) +#endif + +namespace ck_tile { + +template +struct CoreLoopScheduler; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + } + else + { + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + } + else + { + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +namespace detail { +CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) +{ +#if CK_TILE_DISABLE_PACKED_FP32 + return a * b + c; +#else + float result; + asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]" + : [result] "=v"(result) + : [a] "v"(a), [b] "s"(b), [c] "v"(c)); + return result; +#endif +} + +CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + +CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + +CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) +{ + fp16x2_t result; + asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b) +{ + bf16x2_t result; + asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) +{ + fp32x2_t result; + asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} +} // namespace detail + +template +struct UnifiedAttentionPipeline +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + + static_assert(std::is_same_v, + "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); + + using UnifiedAttentionShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; + + static constexpr ck_tile::index_t kM0 = UnifiedAttentionShape::kM0; + static constexpr ck_tile::index_t kN0 = UnifiedAttentionShape::kN0; + static constexpr ck_tile::index_t kK0 = UnifiedAttentionShape::kK0; + static constexpr ck_tile::index_t kN1 = UnifiedAttentionShape::kN1; + static constexpr ck_tile::index_t kK1 = UnifiedAttentionShape::kK1; + static constexpr ck_tile::index_t kQKHeaddim = UnifiedAttentionShape::kQKHeaddim; + static constexpr ck_tile::index_t kSubQKHeaddim = UnifiedAttentionShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr ck_tile::index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr ck_tile::index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr ck_tile::index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr ck_tile::index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr ck_tile::index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + return 2; + } + }(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // create another LDS buffer for p + return ck_tile::max(kM0 * kN1 * sizeof(PDataType), + Policy::template GetSmemSize() + + kM0 * kN0 * sizeof(PDataType)); + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc() + { + using namespace ck_tile; + constexpr auto lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number<1>{}, + number<1>{}); + + return lds_block_desc; + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D() + { + using namespace ck_tile; + constexpr auto lds_block_desc = make_naive_tensor_descriptor( + make_tuple(number{}), make_tuple(number<1>{}), number<1>{}, number<1>{}); + + return lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc) + { + using namespace ck_tile; + + auto tensor_view = + make_tensor_view(reinterpret_cast(base), desc); + return make_tile_window(tensor_view, desc.get_lengths(), {0, 0}); + } + + // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7 + template + CK_TILE_DEVICE static constexpr void s_waitcnt() + { + // vmcnt use bits {[15:14],[3:0]} + // expcnt use bits [6:4] + // lgkmcnt use bits [11:8] + __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) | + ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8)); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt() + { + s_waitcnt(); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt() + { + s_waitcnt<63, Lgkmcnt>(); + } + + template + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + auto s_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto s_lds_window = + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + + auto p_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto p_lds_window = + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + + auto o_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto o_lds_window = + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + + auto m_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc1D()); + [[maybe_unused]] auto m_lds_window = + make_tile_window(m_lds, make_tuple(number{}), {0}); + + const index_t warp_group_id = get_warp_id() / 4; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + auto q_dram_window = make_tile_window_linear( + q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution()); + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto k_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + auto v_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + statically_indexed_array( + nullptr, + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution())), + 2> + k_lds_window_load; + + statically_indexed_array( + nullptr, + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution())), + 2> + v_lds_window_load; + + decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())) q_tile; + + union kv_tile_type + { + CK_TILE_DEVICE kv_tile_type() {} + + decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile; + + decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; + } kv_tile; + + union sp_compute_type + { + CK_TILE_DEVICE sp_compute_type() {} + + decltype(gemm_0.MakeCBlockTile()) sp_compute; + decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())) p; + }; + statically_indexed_array sp; + + decltype(gemm_1.MakeCBlockTile()) o_acc; + constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() + // instructions should we move to fmha_alu1() + static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); + + decltype(block_tile_reduce( + sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m; + decltype(m) l; + + // initialize k_lds_window and v_lds_window + static_for<0, 2, 1>{}([&](auto idx) { + k_lds_window_load(idx) = make_tile_window( + make_lds_tile_window( + static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution()); + }); + + static_for<0, 2, 1>{}([&](auto idx) { + v_lds_window_load(idx) = + make_tile_window(make_lds_tile_window( + static_cast(smem_ptr) + + (idx + 2) * Policy::template GetSmemSizeKV(), + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution()); + }); + + { + auto origin_q = load_tile(q_dram_window); + auto transformed_q = tile_elementwise_in(q_element_func, origin_q); + + q_tile = transformed_q; + } + + clear_tile(o_acc); + set_tile(m, bit_cast(0xff7fffff)); // a bit larger than -infinity + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + index_t kv_token_start = seqlen_k_start; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + k_dram_window.init_raw(); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + v_dram_window.init_raw(); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + static_assert(1 == k0_loops); + static_assert(1 == k1_loops); + static_assert(kN0 == kK1); + + constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; + static_assert(NumWarpGroups == 2); + + [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) { + printf("[POYENC] %s (size=%d): %5.2f", + name, + decltype(dist_tensor.thread_buf_)::size(), + ck_tile::type_convert(dist_tensor.thread_buf_[0])); + static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) { + printf(", %5.2f", ck_tile::type_convert(dist_tensor.thread_buf_[i])); + }); + printf("\n"); + }; + + [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) { + const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{}); + const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + if constexpr(true || num_rows < num_cols) + { + for(int row = 0; row < num_rows; ++row) + { + int offset = desc.calculate_offset(make_tuple(row, 0)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + row, + ck_tile::type_convert(data[offset])); + for(int col = 1; col < num_cols; ++col) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + else + { + for(int col = 0; col < num_cols; ++col) + { + int offset = desc.calculate_offset(make_tuple(0, col)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + col, + ck_tile::type_convert(data[offset])); + for(int row = 1; row < num_rows; ++row) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + }; + + [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) { + const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + int offset = desc.calculate_offset(make_tuple(0)); + printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert(data[offset])); + for(int e = 1; e < num_elems; ++e) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(e)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + }; + + // K_mem_su_ld_insts = 1 for 32 x 128 + // V_mem_su_ld_insts = 1 for 128 x 32 + constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); + constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); + + auto K_mem_load = [&](auto k_lds_write_idx) { + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + + /// FIXME: use the future-predicting method to move the window + // move K tile windows + move_tile_window(k_dram_window, {kN0, 0}); + }; + + auto K_lds_load = [&](auto k_lds_read_idx) { + kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); + }; + + auto V_mem_load = [&](auto v_lds_write_idx) { + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + + /// FIXME: use the future-predicting method to move the window + move_tile_window(v_dram_window, {kK1, 0}); + }; + + auto V_lds_load = [&](auto v_lds_read_idx) { + kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + }; + + decltype(m) m_old; + SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() + /// TODO: remove the sp_delta and use sp_compute directly + statically_indexed_array{}).sp_compute), 2> sp_delta; + + auto fmha_alu0 = [&](auto sp_reg_idx) { + m_old = m; // m{j-1} + static_assert(m.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowmax value"); + auto m_latest = block_tile_reduce( + sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(m_latest.thread_buf_[0]), + bit_cast(m_latest.thread_buf_[0]), + false, + false); + /// TODO: eliminate 2 redudant v_max_f32 instructions generated by the compiler + m_latest.thread_buf_[0] = f_max(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(m_latest, f_max, bool_constant{}); +#endif + m = m_latest; + + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + }); + }); + /// TODO: move some fmha_alu1() code here if necessary + }; + + auto fmha_alu1 = [&](auto sp_reg_idx) { + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp(sp_reg_idx).sp_compute(i_j_idx) = + ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx)); + }); + }); + + auto rowsum_p = block_tile_reduce( + sp(sp_reg_idx).sp_compute, + sequence<1>{}, + f_sum, + SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + static_assert(rowsum_p.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowsum value"); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(rowsum_p.thread_buf_[0]), + bit_cast(rowsum_p.thread_buf_[0]), + false, + false); + rowsum_p.thread_buf_[0] = f_sum(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#endif + + // l{j} + /// Note: The compiler keeps moving the following instructions elsewhere because 'l' + /// is first consumed later. To anchor them here, we rewrite the final addition in + /// inline assembly to create a dependency, forcing the dependent instructions to + /// be emitted at this point. + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); + + l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); + }); + + // update partial o_acc [0, fmha_alu_D_reg_cnt) + static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) { + o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale); + }); + + /// Note: The compiler keeps sinking the conversion instructions because the + /// result 'p' is only consumed later. To anchor them here, we rewrite + /// the cast_tile() call as inline assembly, forcing the conversions to be + /// emitted at this point. + static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); + static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { + float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); + float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); + if constexpr(std::is_same_v) + { + auto casted = detail::cvt_pk_fp16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + else + { + auto casted = detail::cvt_pk_bf16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + }); + + /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly + /// can interfere with the behavior of sched_group_barrier(), so ending the phase here + /// avoids unintended reordering. + }; + + auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + } + }; + + auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + fmha_alu0(number<1>{} - sp_reg_idx); + } + }; + + auto fmha_alu_D_upd = [&] { + o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + + fp32x2_t pk_o_acc_scale; + pk_o_acc_scale.x = o_acc_scale; + pk_o_acc_scale.y = o_acc_scale; + + static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0); +#if CK_TILE_DISABLE_PACKED_FP32 + static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size()); + static_for{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); +#endif + + constexpr auto issued_D_reg_cnt = +#if CK_TILE_DISABLE_PACKED_FP32 + fmha_alu_D_reg_cnt + 2 +#else + fmha_alu_D_reg_cnt +#endif + ; + /// NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call + /// should be placed at the end of a phase. + // update partial o_acc after [issued_D_reg_cnt] + static_for{}([&](auto idx) { + fp32x2_t input; + input.x = o_acc.thread_buf_[idx]; + input.y = o_acc.thread_buf_[idx + 1]; + + auto output = detail::pk_mul_f32(input, pk_o_acc_scale); + + o_acc.thread_buf_[idx] = output.x; + o_acc.thread_buf_[idx + 1] = output.y; + }); + }; + + auto fmha_mask = [&](auto sp_reg_idx) { + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(sp(sp_reg_idx).sp_compute, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = kv_token_start + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + }; + + auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) { + if constexpr(load_type == 0) + { + V_mem_load(mem_wr_idx); + K_lds_load(lds_rd_idx); + } + else + { + K_mem_load(mem_wr_idx); + V_lds_load(lds_rd_idx); + } + }; + + auto core_loop = [&](auto cl_p) { + auto gemm0 = number<0>{}; + auto gemm1 = number<1>{}; + + auto memV = number<0>{}; + auto memK = number<1>{}; + + using Scheduler = CoreLoopScheduler; + + auto iteration = [&](auto pi) { + auto xdl_SP_p01_reg_idx = number<1>{} - pi; + auto xdl_SP_p23_reg_idx = pi; + + auto K_w0_lds_wr_idx = number<1>{} - pi; + auto V_w0_lds_wr_idx = pi; + auto K_w0_lds_rd_idx = pi; + auto V_w0_lds_rd_idx = pi; + + auto K_w4_lds_wr_idx = number<1>{} - pi; + auto V_w4_lds_wr_idx = number<1>{} - pi; + auto K_w4_lds_rd_idx = number<1>{} - pi; + auto V_w4_lds_rd_idx = pi; + + bool result = true; + + if constexpr(cl_p == 0) + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave0-3 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave0-3 (pi=1)"); + } + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + Scheduler::schedule(cl_p, number<1>{}); + fmha_mask(xdl_SP_p01_reg_idx); + + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave0-3"); + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 0"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<2>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); + + Scheduler::schedule(cl_p, number<3>{}); + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + } + else + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave4-7 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave4-7 (pi=1)"); + } + cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave4-7"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + Scheduler::schedule(cl_p, number<2>{}); + fmha_mask(xdl_SP_p01_reg_idx); + + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<3>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + } + return result; + }; + return iteration(number<0>{}) && iteration(number<1>{}); + }; + + auto fmha_post_process = [&](auto d) { + auto ps_pi = number<1>{} - d; + auto V_lds_rd_idx = ps_pi; + + if(1 < num_total_loop) + { + s_waitcnt_vmcnt(); + } + else + { + s_waitcnt_vmcnt<0>(); + } + __builtin_amdgcn_s_barrier(); + + V_lds_load(V_lds_rd_idx); + fmha_alu1(ps_pi); + + s_waitcnt_lgkmcnt<0>(); + + auto xdl_SP_p23_reg_idx = ps_pi; + gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); + }; + + // pre-stage + { + ASM_MARKER("before pre-stage"); + // (1) load K0 to LDS & VGPR + K_mem_load(number<0>{}); // mem_K0 + + s_waitcnt_vmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + K_lds_load(number<0>{}); // lds_K0 + + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + // (2) prefetch K1 and V0 to LDS in parallel with GEMM0 + if(1 < num_total_loop) + { + K_mem_load(number<1>{}); // mem_K1 + } + V_mem_load(number<0>{}); // mem_V0 + + // (3) mfma (Q*K0) + softmax + gemm(number<0>{}, /*gemm_idx=*/number<0>{}); + + fmha_mask(number<0>{}); + /// TODO: find better way to map fmha_alu(0,96) call + fmha_alu0(number<0>{}); + fmha_alu_D_upd(); + + kv_token_start += kN0; + ++i_total_loops; + if(num_total_loop <= i_total_loops) + { + goto label_main_loops_exit; + } + + if(2 < num_total_loop) + { + K_mem_load(number<0>{}); // mem_K2 + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + } + + ASM_MARKER("end pre-stage"); + } + + if(1 < num_total_loop) + { + if(warp_group_id == 0) + { + V_mem_load(number<1>{}); // V1 + K_lds_load(number<1>{}); // K1 + + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<0>{})) + ; + } + if(warp_group_id != 0) + { + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<1>{})) + ; + } + } + label_main_loops_exit: + if(num_total_loop % 2) + { + fmha_post_process(number<1>{}); + } + if(!(num_total_loop % 2)) + { + fmha_post_process(number<0>{}); + } + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp new file mode 100644 index 0000000000..bfbb1a93f0 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -0,0 +1,603 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" + +namespace ck_tile { + +struct UnifiedAttentionPipelineDefaultPolicy +{ + static constexpr ck_tile::index_t NumWarpPerGroup = 4; + static constexpr ck_tile::index_t NumThreadPerWarpGroup = + NumWarpPerGroup * ck_tile::get_warp_size(); + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentK() + { + using namespace ck_tile; + using KDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(KDataType); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentV() + { + using namespace ck_tile; + using VDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentV(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + // compute the endcoding before transpose + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto GetQKBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, + typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + + return BlockGemmARegBRegCRegV2{}; + } + + template + CK_TILE_DEVICE static constexpr auto GetPVBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, + typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; + /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass + /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single + using WarpGemm = WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + return BlockGemmARegBRegCRegV2{}; + } + + static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + + template + CK_TILE_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + + template + CK_TILE_DEVICE static constexpr auto + MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return v_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() + { + using namespace ck_tile; + + static_assert(MakeKLdsLoadBlockDescriptor().get_element_space_size() == + MakeKLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t k_element_space_size = + MakeKLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(MakeVLdsLoadBlockDescriptor().get_element_space_size() == + MakeVLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t v_element_space_size = + MakeVLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <= + GetSingleSmemElementSpaceSize()); + + /// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() & + /// MakeVLdsBlockDescriptor() + static_assert(std::is_same_v); + constexpr index_t kv_element_space_size_in_bytes = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return kv_element_space_size_in_bytes; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 4 * GetSmemSizeKV(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp new file mode 100644 index 0000000000..45a1c8f4b8 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaPipelineEnum +{ + QRKSVS = 0, + QRKSVS_ASYNC, + QSKSVS, + QRKSVS_ASYNC_TRLOAD, +}; + +template +struct BlockFmhaPipelineEnumToStr; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qs"; +}; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async_trload"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp new file mode 100644 index 0000000000..8c8ccc3bd2 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" + +namespace ck_tile { + +template +struct UnifiedAttentionPipelineProblem +{ + // TODO kM0 and KN1?? + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + // first gemm accumulation dtype + using SaccDataType = remove_cvref_t; + // Softmax dtype + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + // data type for A matrix of second gemm + using PDataType = remove_cvref_t; + // data type for second gemm accumulation + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using UnifiedAttentionShape = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps; + static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size(); + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; + static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; +} \ No newline at end of file