diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 3d32c40535..f17aa3a31e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -15,8 +15,10 @@ #include "hstu_attention_hdim_switch.hpp" #include "hstu_attention_pipeline_problem.hpp" #include "hstu_attention_traits.hpp" -#include "hstu_attention_fwd_pipeline.hpp" -#include "hstu_attention_fwd_trload_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp" #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" @@ -66,38 +68,52 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access constexpr bool kPadSeqLenQ = false; - BOOL_SWITCH_3(pad_seqlen_k, - kPadSeqLenK, - pad_headdim_qk, - kPadHeadDimQK, - pad_headdim_v, - kPadHeadDimV, - [&] { - using HstuTraits = ck_tile::HstuAttentionFwdTraits; + BOOL_SWITCH_3( + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_qk, + kPadHeadDimQK, + pad_headdim_v, + kPadHeadDimV, + [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; - using HstuPipelineProblem = HstuPipelineProblemTemp; + using HstuPipelineProblem = HstuPipelineProblemTemp; - using HstuEpilogue = - ck_tile::NRepetitions2DEpilogue::OaccDataType, - typename HstuAttentionFwdTypeConfig::ODataType, - kPadSeqLenQ, - kPadHeadDimV>>; + using HstuEpilogue = + ck_tile::NRepetitions2DEpilogue::OaccDataType, + typename HstuAttentionFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; - using HstuPipeline = std::conditional_t< - kUseTrLoad, - ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad, - ck_tile::HstuAttentionFwdPipelineQRKSVS>; + if constexpr(!kUseTrLoad) + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS>; - using HstuKernel = - ck_tile::HstuAttentionFwdKernel; + using HstuKernel = ck_tile::HstuAttentionFwdKernel; - RunWithKernel(param, stream); - }); + RunWithKernel(param, stream); + } + else + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad>; + + using HstuKernel = ck_tile::HstuAttentionFwdKernel; + + RunWithKernel(param, stream); + }; + }); }; template diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 42d4ae405c..d72526d1b3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -15,8 +15,10 @@ #include "hstu_attention_hdim_switch.hpp" #include "hstu_attention_pipeline_problem.hpp" #include "hstu_attention_traits.hpp" -#include "hstu_attention_fwd_pipeline.hpp" -#include "hstu_attention_fwd_trload_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp" #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" @@ -82,13 +84,28 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch kPadSeqLenQ, kPadHeadDimV>>; - using HstuPipeline = std::conditional_t< - kUseTrLoad, - ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad, - ck_tile::HstuAttentionFwdPipelineQRKSVS>; - using HstuKernel = ck_tile::HstuAttentionFwdKernel; + if constexpr(!kUseTrLoad) + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS>; - RunWithKernel(param, stream); + using HstuKernel = ck_tile::HstuAttentionFwdKernel; + + RunWithKernel(param, stream); + } + else + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad>; + + using HstuKernel = ck_tile::HstuAttentionFwdKernel; + + RunWithKernel(param, stream); + }; }); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp new file mode 100644 index 0000000000..7b23bbe015 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -0,0 +1,549 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" + +#include "hstu_attention_fwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +template +struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QKVDataType = remove_cvref_t; + using GemmAccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + using HstuAttentionTileSetting = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; + static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; + static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; + static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax"); + + static constexpr bool kIsJagged = Problem::kIsJagged; + static constexpr auto kHasBias = Problem::kHasBias; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasCausal = Problem::kHasCausal; + + static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK; + static constexpr bool kPadHeadDimV = + (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); + static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; + + // used by NRepetitions2DEpilogue + static constexpr index_t kGemm1SingleRepN = + Policy::template GetKVBlockGemmSingleRepN(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::Traits::kBlockPerCu != -1) + return Problem::Traits::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(kHasBias) + return 2; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_hstu"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + HstuMask& mask, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLu result + void* smem_ptr, + DropoutType& dropout) const + { + ignore = q_element_func; + ignore = k_element_func; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr index_t k1_loops = kN0 / kK1; + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramSingleRepMTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + + using q_dram_tile_type = decltype(load_tile(q_dram_window)); + statically_indexed_array q_dram_tiles; + + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + q_dram_tiles[i_rep] = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); + }); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + statically_indexed_array k_tiles; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }); + + __builtin_amdgcn_sched_barrier(0); + + // Q tile in LDS + QKVDataType* q_lds_ptr = static_cast(smem_ptr); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_write_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto q_lds_read_window = + make_tile_window(q_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQRegSingleRepMTileDistribution()); + + // K tile in LDS + QKVDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_write_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto k_lds_read_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + using k_lds_write_window_type = decltype(get_slice_tile( + k_lds_write_window, sequence<0, 0>{}, sequence{})); + + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_read_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_write_windows; + statically_indexed_array k_lds_read_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_write_windows[i_buf] = + get_slice_tile(k_lds_write_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + }); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // reduction function for softmax + const auto f_silu = [&](CompDataType& x) { + const auto one = ck_tile::type_convert(1.0f); + + if constexpr(std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(one + __expf(-x)); + } + else + { + x = x / (one + exp(-x)); + } + }; + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + using q_reg_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegSingleRepMTileDistribution())); + statically_indexed_array q_reg_tiles; + + using q_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())); + + q_tile_type q_tile; + + { + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + store_tile(q_lds_write_window, q_dram_tiles[i_rep]); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written + // by each wavefront is read by itself + __builtin_amdgcn_s_waitcnt(0xc07f); + + q_reg_tiles[i_rep] = load_tile(q_lds_read_window); + + __builtin_amdgcn_s_waitcnt(0xc07f); + + // the following codes will not generate actual instructions by the compiler + set_slice_tile(q_tile, + q_reg_tiles[i_rep], + sequence{}, + sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read + // by each wavefront is over-written by itself + }); + + clear_tile(o_acc); + }; + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + auto seqlen_k_curr = seqlen_k_start; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile + __builtin_amdgcn_s_barrier(); + + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + do + { + // STAGE 1, Gemm_0 ( S = Q@K ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(k_lds_write_windows[i_k1], + tile_elementwise_in(k_element_func, k_tiles[i_k1])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + // execute current unroll of gemm_0 + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + + auto tmp_tile = cast_tile(sacc_tile); + + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 2, scale_s, add bias, mask, siLU + if constexpr(kHasBias) + { + const auto bias_tile = load_tile(bias_dram_window); + + tile_elementwise_inout( + [&scale_s, &bias_element_func](auto& x, const auto& y) { + x = x * scale_s + type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + + move_tile_window(bias_dram_window, {0, kN0}); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(!mask.IsTokenPairInsideMask(row, col)) + { + pcomp_tile(i_j_idx) = type_convert(0.0f); + }; + }); + }); + } + + tile_elementwise_inout(f_silu, pcomp_tile); + + tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, + pcomp_tile); + + seqlen_k_curr += kN0; + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + using v_shuffled_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution())); + + statically_indexed_array v_shuffled_tiles; + + static_for<0, k1_loops, 1>{}( + [&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); }); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + HstuMask mask, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLU result + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + scale_p, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp new file mode 100644 index 0000000000..7d4a1147f7 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -0,0 +1,540 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" + +#include "hstu_attention_fwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +template +struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QKVDataType = remove_cvref_t; + using GemmAccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + using HstuAttentionTileSetting = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; + static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; + static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; + static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax"); + + static constexpr bool kIsJagged = Problem::kIsJagged; + static constexpr auto kHasBias = Problem::kHasBias; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasCausal = Problem::kHasCausal; + + static_assert(Problem::kUseTrLoad == true, "Check failed!"); + + static constexpr bool kUseTrLoad = true; + + static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK; + static constexpr bool kPadHeadDimV = + (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); + static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; + + // used by NRepetitions2DEpilogue + static constexpr index_t kGemm1SingleRepN = + Policy::template GetKVBlockGemmSingleRepN(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::Traits::kBlockPerCu != -1) + return Problem::Traits::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(kHasBias) + return 2; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_hstu"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + HstuMask& mask, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLu result + void* smem_ptr, + DropoutType& dropout) const + { + ignore = q_element_func; + ignore = k_element_func; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr index_t k1_loops = kN0 / kK1; + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramSingleRepMTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + + using q_dram_tile_type = decltype(load_tile(q_dram_window)); + statically_indexed_array q_dram_tiles; + + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + q_dram_tiles[i_rep] = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); + }); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + statically_indexed_array k_tiles; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }); + + __builtin_amdgcn_sched_barrier(0); + + // Q tile in LDS + QKVDataType* q_lds_ptr = static_cast(smem_ptr); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_write_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto q_lds_read_window = + make_tile_window(q_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQRegSingleRepMTileDistribution()); + + // K tile in LDS + QKVDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_write_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto k_lds_read_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + using k_lds_write_window_type = decltype(get_slice_tile( + k_lds_write_window, sequence<0, 0>{}, sequence{})); + + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_read_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_write_windows; + statically_indexed_array k_lds_read_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_write_windows[i_buf] = + get_slice_tile(k_lds_write_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + }); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kN1>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); + + // reduction function for softmax + const auto f_silu = [&](CompDataType& x) { + const auto one = ck_tile::type_convert(1.0f); + + if constexpr(std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(one + __expf(-x)); + } + else + { + x = x / (one + exp(-x)); + } + }; + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + using q_reg_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegSingleRepMTileDistribution())); + statically_indexed_array q_reg_tiles; + + using q_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())); + + q_tile_type q_tile; + + { + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + store_tile(q_lds_write_window, q_dram_tiles[i_rep]); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written + // by each wavefront is read by itself + __builtin_amdgcn_s_waitcnt(0xc07f); + + q_reg_tiles[i_rep] = load_tile(q_lds_read_window); + + __builtin_amdgcn_s_waitcnt(0xc07f); + + // the following codes will not generate actual instructions by the compiler + set_slice_tile(q_tile, + q_reg_tiles[i_rep], + sequence{}, + sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read + // by each wavefront is over-written by itself + }); + + clear_tile(o_acc); + }; + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + auto seqlen_k_curr = seqlen_k_start; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile + __builtin_amdgcn_s_barrier(); + + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + do + { + // STAGE 1, Gemm_0 ( S = Q@K ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(k_lds_write_windows[i_k1], + tile_elementwise_in(k_element_func, k_tiles[i_k1])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + // execute current unroll of gemm_0 + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + + auto tmp_tile = cast_tile(sacc_tile); + + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 2, scale_s, add bias, mask, siLU + if constexpr(kHasBias) + { + const auto bias_tile = load_tile(bias_dram_window); + + tile_elementwise_inout( + [&scale_s, &bias_element_func](auto& x, const auto& y) { + x = x * scale_s + type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + + move_tile_window(bias_dram_window, {0, kN0}); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(!mask.IsTokenPairInsideMask(row, col)) + { + pcomp_tile(i_j_idx) = type_convert(0.0f); + }; + }); + }); + } + + tile_elementwise_inout(f_silu, pcomp_tile); + + tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, + pcomp_tile); + + seqlen_k_curr += kN0; + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[number{}])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + __builtin_amdgcn_sched_barrier(0x00000001); + + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + }); + } while(seqlen_k_curr < seqlen_k_end); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + HstuMask mask, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLU result + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + scale_p, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp similarity index 77% rename from example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp rename to example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 1245e4e445..4adb7bf00c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -11,7 +11,7 @@ namespace ck_tile { template -struct HstuAttentionFwdPipelineQRKSVS +struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { using Problem = remove_cvref_t; using Policy = remove_cvref_t; @@ -35,6 +35,8 @@ struct HstuAttentionFwdPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax"); + static constexpr bool kIsJagged = Problem::kIsJagged; static constexpr auto kHasBias = Problem::kHasBias; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -143,6 +145,7 @@ struct HstuAttentionFwdPipelineQRKSVS { ignore = q_element_func; ignore = k_element_func; + ignore = scale_p; static_assert( std::is_same_v> && @@ -160,8 +163,6 @@ struct HstuAttentionFwdPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - constexpr bool kUseSoftmax = Problem::kUseSoftmax; - constexpr index_t k1_loops = kN0 / kK1; constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); @@ -293,20 +294,6 @@ struct HstuAttentionFwdPipelineQRKSVS {0, seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); - // reduction function for softmax - const auto f_silu = [&](CompDataType& x) { - const auto one = ck_tile::type_convert(1.0f); - - if constexpr(std::is_same_v) - { - x = x * __builtin_amdgcn_rcpf(one + __expf(-x)); - } - else - { - x = x / (one + exp(-x)); - } - }; - const auto f_exp = [&](CompDataType x) { if constexpr(std::is_same_v) { @@ -381,11 +368,8 @@ struct HstuAttentionFwdPipelineQRKSVS clear_tile(o_acc); - if constexpr(kUseSoftmax) - { - set_tile(m, -numeric::infinity()); - clear_tile(l); - }; + set_tile(m, -numeric::infinity()); + clear_tile(l); }; q_tile = tile_elementwise_in(q_element_func, q_tile); @@ -454,129 +438,98 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); } - if constexpr(!kUseSoftmax) + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) { - if(!mask.IsFullTileInsideMask( - q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) - { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); - if(!mask.IsTokenPairInsideMask(row, col)) - { - pcomp_tile(i_j_idx) = type_convert(0.0f); - }; - }); + if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; }); - } - - tile_elementwise_inout(f_silu, pcomp_tile); - - tile_elementwise_inout( - [&](auto& x) { x = x * type_convert(scale_p); }, pcomp_tile); + }); } else { - if(!mask.IsFullTileInsideMask( - q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; + }); + }); + }; + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end) - { - pcomp_tile(i_j_idx) = -numeric::infinity(); - }; - }); + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); }); } else { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - if(col >= seqlen_k_end) - { - pcomp_tile(i_j_idx) = -numeric::infinity(); - }; - }); + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); }); - }; + } + }); - auto m_local = block_tile_reduce( - pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); - block_tile_reduce_sync(m_local, f_max, bool_constant{}); + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); - const auto m_old = m; + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); - constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - - if(m[i_idx] == -numeric::infinity()) - { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - pcomp_tile(i_j_idx) = type_convert(0.0f); - }); - } - else - { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); - }); - } - }); - - auto rowsum_p = block_tile_reduce( - pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); - - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - - // adjust o_acc[] according to the update between m and m_old - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - - if(m[i_idx] == -numeric::infinity()) - { - l(i_idx) = rowsum_p[i_idx]; - } - else - { - const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - } - }); - }; + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); seqlen_k_curr += kN0; @@ -635,22 +588,19 @@ struct HstuAttentionFwdPipelineQRKSVS }; } while(seqlen_k_curr < seqlen_k_end); - if constexpr(kUseSoftmax) - { - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); - if(m[i_idx] == -numeric::infinity()) - o_acc(i_j_idx) = 0.0f; - else - o_acc(i_j_idx) *= 1.0f / l[i_idx]; - }); + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; }); - }; + }); o_acc = tile_elementwise_in(o_acc_element_func, o_acc); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp similarity index 77% rename from example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp rename to example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 140ef8f07b..7b0262c598 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -11,7 +11,7 @@ namespace ck_tile { template -struct HstuAttentionFwdPipelineQRKSVSTrLoad +struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad { using Problem = remove_cvref_t; using Policy = remove_cvref_t; @@ -35,6 +35,8 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax"); + static constexpr bool kIsJagged = Problem::kIsJagged; static constexpr auto kHasBias = Problem::kHasBias; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -143,6 +145,7 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad { ignore = q_element_func; ignore = k_element_func; + ignore = scale_p; static_assert( std::is_same_v> && @@ -160,8 +163,6 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - constexpr bool kUseSoftmax = Problem::kUseSoftmax; - constexpr index_t k1_loops = kN0 / kK1; constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); @@ -293,20 +294,6 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad {seqlen_k_start, 0}, Policy::template MakeVDramTileDistribution()); - // reduction function for softmax - const auto f_silu = [&](CompDataType& x) { - const auto one = ck_tile::type_convert(1.0f); - - if constexpr(std::is_same_v) - { - x = x * __builtin_amdgcn_rcpf(one + __expf(-x)); - } - else - { - x = x / (one + exp(-x)); - } - }; - const auto f_exp = [&](CompDataType x) { if constexpr(std::is_same_v) { @@ -381,11 +368,8 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad clear_tile(o_acc); - if constexpr(kUseSoftmax) - { - set_tile(m, -numeric::infinity()); - clear_tile(l); - }; + set_tile(m, -numeric::infinity()); + clear_tile(l); }; q_tile = tile_elementwise_in(q_element_func, q_tile); @@ -454,129 +438,98 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); } - if constexpr(!kUseSoftmax) + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) { - if(!mask.IsFullTileInsideMask( - q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) - { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); - if(!mask.IsTokenPairInsideMask(row, col)) - { - pcomp_tile(i_j_idx) = type_convert(0.0f); - }; - }); + if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; }); - } - - tile_elementwise_inout(f_silu, pcomp_tile); - - tile_elementwise_inout( - [&](auto& x) { x = x * type_convert(scale_p); }, pcomp_tile); + }); } else { - if(!mask.IsFullTileInsideMask( - q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; + }); + }); + }; + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end) - { - pcomp_tile(i_j_idx) = -numeric::infinity(); - }; - }); + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); }); } else { - constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - if(col >= seqlen_k_end) - { - pcomp_tile(i_j_idx) = -numeric::infinity(); - }; - }); + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); }); - }; + } + }); - auto m_local = block_tile_reduce( - pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); - block_tile_reduce_sync(m_local, f_max, bool_constant{}); + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); - const auto m_old = m; + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); - constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - - if(m[i_idx] == -numeric::infinity()) - { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - pcomp_tile(i_j_idx) = type_convert(0.0f); - }); - } - else - { - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); - }); - } - }); - - auto rowsum_p = block_tile_reduce( - pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); - - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - - // adjust o_acc[] according to the update between m and m_old - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - - if(m[i_idx] == -numeric::infinity()) - { - l(i_idx) = rowsum_p[i_idx]; - } - else - { - const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - } - }); - }; + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); seqlen_k_curr += kN0; @@ -622,22 +575,19 @@ struct HstuAttentionFwdPipelineQRKSVSTrLoad }); } while(seqlen_k_curr < seqlen_k_end); - if constexpr(kUseSoftmax) - { - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); - if(m[i_idx] == -numeric::infinity()) - o_acc(i_j_idx) = 0.0f; - else - o_acc(i_j_idx) *= 1.0f / l[i_idx]; - }); + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; }); - }; + }); o_acc = tile_elementwise_in(o_acc_element_func, o_acc);