diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 7df548daeb..25679561a1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -22,6 +22,8 @@ #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" +#include "hstu_attention_batched_forward_splitkv_dispatch.hpp" + template ::Run(param, stream); else - batched_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + { + const bool disable_fwd_splitkv = []() { + const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV"); + if(env_p == nullptr) + return false; + return static_cast(atoi(env_p)); + }(); + + // ToDo: enable splitkv when kUseSoftmax is true + if(!disable_fwd_splitkv && !kUseSoftmax && + shall_use_splitkv(param.num_batch, param.num_head, param.seqlen_q)) + { + if constexpr(!kUseSoftmax) + { + batched_forward_splitkv_causal_softmax_bias_dropout_dispatch::Run(param, + stream); + }; + } + else + batched_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + }; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp new file mode 100644 index 0000000000..8af4deee3c --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_fwd_setting.hpp" +#include "hstu_attention_fwd_splitkv_combine_setting.hpp" +#include "hstu_attention_params.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_pipeline_problem.hpp" +#include "hstu_attention_traits.hpp" +#include "hstu_attention_with_softmax_fwd_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp" +#include "hstu_attention_fwd_splitkv_kernel.hpp" +#include "hstu_attention_fwd_splitkv_combine_kernel.hpp" + +template +struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch +{ + static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!"); + static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!"); + + using HstuAttentionFwdTileSetting = + typename std::conditional_t, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; + using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting::Type; + +#ifdef BUILD_HSTU_FOR_GFX95_ONLY + static constexpr bool kUseTrLoad = true; +#else + static constexpr bool kUseTrLoad = false; +#endif + + template + using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< + InOutDataType, + typename HstuAttentionFwdTypeConfig::GemmAccDataType, + typename HstuAttentionFwdTypeConfig::CompDataType, + typename HstuAttentionFwdTypeConfig::BiasDataType, + kIsCrossAttention, + false, // kUseGroup + false, // kIsJagged + kHasBias, + kHasDropout, + kUseCausal, + kUseSoftmax, + kUseTrLoad, + HstuAttentionFwdTileSetting>; + + using OaccDataType = HstuAttentionFwdTypeConfig::OaccDataType; + using ODataType = HstuAttentionFwdTypeConfig::ODataType; + 
+ static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) + { + constexpr ck_tile::index_t occupancy = -1; + + { + const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionFwdTileSetting::kN0 == 0); + const bool pad_headdim_qk = + !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + BOOL_SWITCH_3( + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_qk, + kPadHeadDimQK, + pad_headdim_v, + kPadHeadDimV, + [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; + + using HstuEpilogue = + ck_tile::Default2DEpilogue>; + + BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { + using HstuPipelineProblem = HstuFwdPipelineProblemTemp; + + if constexpr(!kUseTrLoad) + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS< + HstuPipelineProblem, + HstuTraits>, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS< + HstuPipelineProblem, + HstuTraits>>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + else + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad< + HstuPipelineProblem, + HstuTraits>, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad< + HstuPipelineProblem, + HstuTraits>>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + }; + }); + }); + }; + { + using HstuCombinePipelineProblem = + ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem< + OaccDataType, + ODataType, + false /* kIsJagged */, + kUseSoftmax, + HstuAttentionCombineTileSetting>; + const bool pad_headdim_o = + !(param.hdim_v % 
HstuAttentionCombineTileSetting::kOHeaddim == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] { + using HstuTraits = ck_tile:: + HstuAttentionFwdSplitKVCombineTraits; + + using HstuEpilogue = + ck_tile::Default2DEpilogue>; + + using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline< + HstuCombinePipelineProblem, + HstuTraits>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVCombineKernel; + + RunWithFwdSplitKVCombineKernel(param, stream); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) + { + param.num_splits = + get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q); + + // assume the workspace for o_acc is in compact shape of [num_batch, seqlen_q, num_head, + // num_splits, hdim] + size_t workspace_bytes = static_cast(param.num_batch) * param.seqlen_q * + param.num_head * param.num_splits * param.hdim_v * + sizeof(OaccDataType); + + HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream)); + + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.q_ptr, + param.k_ptr, + param.v_ptr, + param.bias_ptr, + param.o_acc_ptr, + param.num_splits, + param.seqlen_q, + param.is_cross_attention ? 
param.seqlen_kv + : param.seqlen_q, + param.hdim_qk, + param.hdim_v, + param.num_head, + param.scale_s, + param.attn_scale, + param.seq_stride_q, + param.seq_stride_k, + param.seq_stride_v, + param.seq_stride_bias, + param.nhead_stride_q, + param.nhead_stride_k, + param.nhead_stride_v, + param.nhead_stride_bias, + param.batch_stride_q, + param.batch_stride_k, + param.batch_stride_v, + param.batch_stride_bias, + param.num_targets_ptr, + param.contextual_seqlen, + param.window_size, + param.min_full_attn_seqlen, + param.p_drop, + param.philox_seed, + param.philox_offset); + }(); + + bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, + param.num_head, + param.seqlen_q, + param.hdim_v, + param.num_splits, + has_minfull_attn_seqlen); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param, + hipStream_t stream) + { + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.o_acc_ptr, + param.o_ptr, + param.batch_stride_o, + param.seq_stride_o, + param.nhead_stride_o, + param.seqlen_q, + param.num_head, + param.num_splits, + param.hdim_v); + }(); + + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen_q); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + + HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream)); + }; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp 
b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 98c912fc5f..0c82f94c06 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -484,3 +484,36 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_ // mtile-64 can be added through tuning/verification return 64; }; + +static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q) +{ + int num_CUs = get_number_of_cu(); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + + int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64); + + // assume each CU can run two work-groups, common cases for hdim128 + return static_cast(nbatch_nhead_mblocks) / (2.0f * num_CUs); +}; + +static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q) +{ + // Please tune the threshold here + const float threshold = 0.8f; + + if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold) + return true; + return false; +}; + +static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q) +{ + int i = 2; + + // Please tune the threshold here + const float threshold = 1.5f; + while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold) + i++; + + return i; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp new file mode 100644 index 0000000000..b5c4593baf --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" + +#include +#include +#include +#include + +#include "hstu_block_masking.hpp" + +#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM +#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1 +#endif + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct HstuAttentionFwdSplitKVCombineKernel +{ + using HstuAttentionPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = HstuAttentionPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = HstuAttentionPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using OaccDataType = + ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged; + + static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDimO = HstuAttentionPipeline::kPadHeadDimO; + + template // to avoid duplicated base class problem, introduce an template + // arg + struct HstuAttentionCombineEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
+ struct HstuAttentionBatchedCombineBaseKargs + { + const void* o_acc_ptr; + void* o_ptr; + + ck_tile::index_t batch_stride_o; + ck_tile::index_t seq_stride_o; + ck_tile::index_t nhead_stride_o; + + ck_tile::index_t seqlen_q; + ck_tile::index_t num_head; + ck_tile::index_t num_splits; + ck_tile::index_t hdim_v; + }; + + struct HstuAttentionJaggedCombineBaseKargs + { + const void* o_acc_ptr; + void* o_ptr; + + ck_tile::index_t seq_stride_o; + ck_tile::index_t nhead_stride_o; + + const int32_t* seq_q_offsets_ptr; + ck_tile::index_t num_head; + ck_tile::index_t num_splits; + ck_tile::index_t hdim_v; + + ck_tile::index_t seqlen_q; + }; + + struct HstuAttentionBatchedCombineKargs : HstuAttentionBatchedCombineBaseKargs + { + }; + + struct HstuAttentionJaggedCombineKargs : HstuAttentionJaggedCombineBaseKargs + { + }; + + using Kargs = std:: + conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o + void* o_ptr, + ck_tile::index_t batch_stride_o, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t seqlen_q, + ck_tile::index_t num_head, + ck_tile::index_t num_splits, // number of splitted seqlen_kv + ck_tile::index_t hdim_v) + { + Kargs kargs{o_acc_ptr, + o_ptr, + batch_stride_o, + seq_stride_o, + nhead_stride_o, + seqlen_q, + num_head, + num_splits, + hdim_v}; + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o + void* o_ptr, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_o, + const void* seq_q_offsets_ptr, + ck_tile::index_t num_head, + ck_tile::index_t num_splits, // number of splitted seqlen_kv + ck_tile::index_t hdim_v) + { + Kargs kargs{o_acc_ptr, + o_ptr, + seq_stride_o, + nhead_stride_o, + reinterpret_cast(seq_q_offsets_ptr), + num_head, + num_splits, + hdim_v, + 0 /* seqlen_q will be updated later */}; + + return kargs; + 
} + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_) + { + ck_tile::index_t num_tile_in_seqlen = + ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM); + +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + return dim3(batch_size_, nhead_, num_tile_in_seqlen); +#else + return dim3(num_tile_in_seqlen), + nhead_, + batch_size_); +#endif + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + ignore = kargs; + +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + const index_t i_batch = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_tile_m = blockIdx.z; +#else + const index_t i_tile_m = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; +#endif + return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(kargs); + + long_index_t batch_offset_o_acc = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsJagged) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch]; + + // assume o_acc is in compact shape of [batch_size, max_seqlen, num_head, num_splits, + // hdim] + batch_offset_o_acc = query_start * kargs.num_head * kargs.num_splits * kargs.hdim_v; + + batch_offset_o = query_start * kargs.seq_stride_o; + + kargs.seqlen_q = + kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch]; + } + else + { + // assume o_acc is in compact shape of [batch_size, seqlen_q, num_head, num_splits, + // hdim] + batch_offset_o_acc = static_cast(i_batch) * kargs.seqlen_q * + kargs.num_head * 
kargs.num_splits * kargs.hdim_v; + + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + index_t i_m0; + + i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM); + + if(kargs.seqlen_q <= i_m0) + return; + + // assume o_acc is in compact shape of [batch_size, seqlen, num_head, num_splits, hdim] + const OaccDataType* o_acc_ptr = + reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.num_splits * kargs.hdim_v + + batch_offset_o_acc; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Oacc DRAM and Oacc DRAM window + auto seq_stride_o_acc = kargs.num_head * kargs.num_splits * kargs.hdim_v; + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(seq_stride_o_acc, 1), + number{}, + number<1>{}); + + const auto o_acc_dram = + pad_tensor_view(o_acc_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + + auto o_acc_dram_window = + make_tile_window(o_acc_dram, + make_tuple(number{}, + number{}), + {i_m0, 0}); + + auto o_acc_tile = [&]() { + return HstuAttentionPipeline{}(o_acc_dram_window, kargs.hdim_v, kargs.num_splits); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.seq_stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view(o_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, + number{}), + {i_m0, 0}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp new file mode 100644 index 
0000000000..e6260ba30f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +struct HstuAttentionFwdSplitKVCombinePipelinePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() + { + constexpr index_t kMPerBlock = Problem::kM; + constexpr index_t kKPerBlock = Problem::kOHeaddim; + constexpr index_t NumWarps = Problem::NumWarps; + + constexpr index_t KVector = GetAlignmentOacc(); + constexpr index_t OtherK = kKPerBlock / KVector; + + if constexpr(kKPerBlock == Problem::kSubOHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); + + constexpr index_t KPerThread = KVector; + + // try to assign more consecutive threads on dim-K + constexpr index_t KThreads = OtherK; + + static_assert(KThreads <= get_warp_size(), "Check failed!"); + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + // ensure KThreads be power-of-2 integer + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 
3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + return Problem::GetOaccDramTileAccessMaxVectorSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + // should be same as GetAlignmentOacc() since o_tile will use the same encoding as + // o_acc_tile + return GetAlignmentOacc(); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_setting.hpp new file mode 100644 index 0000000000..bf937477a4 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_setting.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_tile_setting_define.hpp" + +template +struct HstuAttentionFwdSplitKVCombineTileSetting; + +template <> +struct HstuAttentionFwdSplitKVCombineTileSetting<64> +{ + using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<32, 4, 64>; +}; + +template <> +struct HstuAttentionFwdSplitKVCombineTileSetting<96> +{ + using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<16, 4, 96>; +}; + +template <> +struct HstuAttentionFwdSplitKVCombineTileSetting<128> +{ + using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<16, 4, 128>; +}; + +template <> +struct HstuAttentionFwdSplitKVCombineTileSetting<256> +{ + using Type = ck_tile::HstuAttentionFwdSplitKVCombineTileSettingClass<8, 4, 256>; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp new file mode 100644 index 0000000000..36760dcd78 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -0,0 +1,1002 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" + +#include +#include +#include +#include + +#include "hstu_block_masking.hpp" + +#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM +#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1 +#endif + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct HstuAttentionFwdSplitKVKernel +{ + using HstuAttentionPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = HstuAttentionPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = HstuAttentionPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using QKVDataType = + ck_tile::remove_cvref_t; + using BiasDataType = + ck_tile::remove_cvref_t; + using OaccDataType = + ck_tile::remove_cvref_t; + + static constexpr bool kIsCrossAttention = HstuAttentionPipeline::Problem::kIsCrossAttention; + static constexpr bool kUseGroup = HstuAttentionPipeline::Problem::kUseGroup; + static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged; + static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias; + static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout; + static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal; + static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad; + + static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQK = 
HstuAttentionPipeline::kPadHeadDimQK; + static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV; + + template // to avoid duplicated base class problem, introduce an template + // arg + struct HstuAttentionFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct HstuAttentionNoGroupBatchedFwdBaseKargs + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + + const int32_t* num_targets_ptr; + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_acc_ptr; + ck_tile::index_t num_splits; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_kv; + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + + ck_tile::index_t num_head; + float scale_s; // scaling value exerted on the immediate Q@K result + float scale_p; // scaling value exerted on the SiLU result + + ck_tile::index_t contextual_seqlen; + ck_tile::index_t window_size; + ck_tile::index_t min_full_attn_seqlen; + }; + + struct HstuAttentionNoGroupJaggedFwdBaseKargs + { + const int32_t* seq_q_offsets_ptr; + const int32_t* seq_kv_offsets_ptr; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + + const int32_t* num_targets_ptr; + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_acc_ptr; + ck_tile::index_t num_splits; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_kv; + + ck_tile::index_t num_head; + float scale_s; // scaling 
value exerted on the immediate Q@K result + float scale_p; // scaling value exerted on the SiLU result + + ck_tile::index_t contextual_seqlen; + ck_tile::index_t window_size; + ck_tile::index_t min_full_attn_seqlen; + }; + + struct HstuAttentionGroupFwdBaseKargs + { + ck_tile::index_t num_batch_per_group; + + const int32_t* seq_q_offsets_ptr; + const int32_t* seq_kv_offsets_ptr; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + + const int32_t* num_targets_ptr; + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_acc_ptr; + ck_tile::index_t num_splits; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_kv; + + ck_tile::index_t num_head; + float scale_s; // scaling value exerted on the immediate Q@K result + float scale_p; // scaling value exerted on the SiLU result + + int32_t contextual_seqlen; // to be set by the per-group contextual_seqlen + int32_t window_size; // to be set by the per-group window_size + int32_t min_full_attn_seqlen; // to be set by the per-group min_full_attn_seqlen + + const int32_t* group_max_seqlen_ptr; + const int32_t* group_contextual_seqlen_ptr; + const int32_t* group_window_size_ptr; + const int32_t* group_min_full_attn_seqlen_ptr; + const float* group_attn_scale_ptr; + }; + + struct HstuAttentionFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t seq_stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct HstuAttentionFwdDropoutSeedOffset + { + uint64_t drop_seed; + uint64_t drop_offset; + }; + + struct HstuAttentionFwdCommonDropoutKargs : HstuAttentionFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, 
uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed = seed; + this->drop_offset = offset; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + }; + + struct HstuAttentionNoGroupBatchedFwdKargs + : HstuAttentionNoGroupBatchedFwdBaseKargs, + std::conditional_t>, + std::conditional_t> + { + }; + + struct HstuAttentionNoGroupJaggedFwdKargs + : HstuAttentionNoGroupJaggedFwdBaseKargs, + std::conditional_t>, + std::conditional_t> + { + }; + + struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs, + std::conditional_t>, + std::conditional_t> + { + }; + + using Kargs = std::conditional_t>; + + static constexpr bool kUseNoGroupBatched = (!kUseGroup && !kIsJagged); + static constexpr bool kUseNoGroupJagged = (!kUseGroup && kIsJagged); + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_acc_ptr, // workspace for accumulation of o + ck_tile::index_t num_splits, // number of splitted seqlen_kv + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_kv, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + float attn_scale, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + const void* num_targets_ptr, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t window_size, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + uint64_t philox_seed, + uint64_t 
philox_offset) + { + Kargs kargs{ + {batch_stride_q, + batch_stride_k, + batch_stride_v, + reinterpret_cast(num_targets_ptr), + q_ptr, + k_ptr, + v_ptr, + o_acc_ptr, + num_splits, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + seqlen_q, + seqlen_kv, + hdim_qk, + hdim_v, + seq_stride_q, + seq_stride_k, + seq_stride_v, + num_head, + scale_s, + attn_scale ? attn_scale + : 1.0f / static_cast(max(seqlen_q, seqlen_kv)), // max_seqlen + contextual_seqlen, + window_size, + min_full_attn_seqlen}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dropout + }; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.seq_stride_bias = seq_stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, philox_seed, philox_offset); + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_acc_ptr, // workspace for accumulation of o + ck_tile::index_t num_splits, // number of splitted seqlen_kv + const void* seq_q_offsets_ptr, + const void* seq_kv_offsets_ptr, + ck_tile::index_t max_seqlen, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + float attn_scale, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + const void* num_targets_ptr, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t window_size, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + uint64_t philox_seed, + uint64_t philox_offset) + { + Kargs kargs{ + {reinterpret_cast(seq_q_offsets_ptr), + reinterpret_cast(seq_kv_offsets_ptr), + 
seq_stride_q, + seq_stride_k, + seq_stride_v, + reinterpret_cast(num_targets_ptr), + q_ptr, + k_ptr, + v_ptr, + o_acc_ptr, + num_splits, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + hdim_qk, + hdim_v, + -1, // seqlen_q will be updated by another pointer + -1, // seqlen_kv will be updated by another pointer + num_head, + scale_s, + attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen), + contextual_seqlen, + window_size, + min_full_attn_seqlen}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dropout + }; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.seq_stride_bias = seq_stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, philox_seed, philox_offset); + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_acc_ptr, // workspace for accumulation of o + ck_tile::index_t num_splits, // number of splitted seqlen_kv + ck_tile::index_t num_batch_per_group, + const void* seq_q_offsets_ptr, + const void* seq_kv_offsets_ptr, + const void* group_max_seqlen_ptr, + const void* group_contextual_seqlen_ptr, + const void* group_window_size_ptr, + const void* group_min_full_attn_seqlen_ptr, + const void* group_attn_scale_ptr, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + const void* num_targets_ptr, + float p_drop, + uint64_t philox_seed, + uint64_t philox_offset) + { + Kargs kargs{ + {num_batch_per_group, + reinterpret_cast(seq_q_offsets_ptr), + 
reinterpret_cast(seq_kv_offsets_ptr), + seq_stride_q, + seq_stride_k, + seq_stride_v, + reinterpret_cast(num_targets_ptr), + q_ptr, + k_ptr, + v_ptr, + o_acc_ptr, + num_splits, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + hdim_qk, + hdim_v, + -1, // seqlen_q will be updated by another pointer + -1, // seqlen_kv will be updated by another pointer + num_head, + scale_s, + 1.0f, // to be set according to the per-group attn_scale and max_seqlen + 0, // to be set by the per-group contextual_seqlen + 0, // to be set by the per-group window_size + 0, // to be set by the per-group min_full_attn_seqlen + reinterpret_cast(group_max_seqlen_ptr), + reinterpret_cast(group_contextual_seqlen_ptr), + reinterpret_cast(group_window_size_ptr), + reinterpret_cast(group_min_full_attn_seqlen_ptr), + reinterpret_cast(group_attn_scale_ptr)}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dropout + }; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.seq_stride_bias = seq_stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, philox_seed, philox_offset); + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_, + ck_tile::index_t hdim_v_, + ck_tile::index_t num_splits, + bool has_minfull_attn_seqlen = false) + { + // The Q sequence [0, seqlen) will be split to two parts for allocating workgroups: + // 1) [0, seqlen - target - min_full_attn_seqlen) + // 2) [seqlen - target - min_full_attn_seqlen, seqlen) + ck_tile::index_t num_tile_in_seqlen = + ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0); + + if constexpr(kUseGroup) + { + num_tile_in_seqlen += 1; + } + else + { + if(has_minfull_attn_seqlen) + num_tile_in_seqlen += 1; + }; + + if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim) + { +#if 
HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + return dim3(batch_size_, + nhead_, + num_tile_in_seqlen * num_splits * + ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1)); +#else + return dim3(num_tile_in_seqlen * num_splits * + ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1), + nhead_, + batch_size_); +#endif + } + else + { +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + return dim3(batch_size_, nhead_, num_tile_in_seqlen * num_splits); +#else + return dim3(num_tile_in_seqlen * num_splits, + nhead_, + batch_size_); +#endif + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim) + { + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1); + +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + const index_t i_batch = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_block = blockIdx.z; +#else + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; +#endif + +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + auto [i_tile_m_i_split, i_tile_n] = f(i_block, num_tile_n1); + auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits); + + i_tile_m = gridDim.z / num_tile_n1 / kargs.num_splits - 1 - i_tile_m; +#else + auto [i_tile_m_i_split, i_tile_n] = f(i_block, num_tile_n1); + auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits); +#endif + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch, i_split); + } + else + { +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + const index_t i_batch = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_block = blockIdx.z; +#else + const index_t i_block = blockIdx.x; 
+ const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; +#endif + +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + index_t i_tile_m_i_split = i_block; + auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits); + i_tile_m = gridDim.z / kargs.num_splits - 1 - i_tile_m; +#else + index_t i_tile_m_i_split = i_block; + auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits); +#endif + const index_t i_tile_n = 0; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch, i_split); + } + } + + CK_TILE_DEVICE static constexpr auto + CalculateTileRangeAlongXForSplit(ck_tile::index_t global_seqlen_k_start, + ck_tile::index_t global_seqlen_k_end, + ck_tile::index_t num_splits, + ck_tile::index_t i_split) + { + ck_tile::index_t num_tile = ck_tile::integer_divide_ceil( + global_seqlen_k_end - global_seqlen_k_start, HstuAttentionPipeline::kN0); + + ck_tile::index_t num_tile_per_split = ck_tile::integer_divide_ceil(num_tile, num_splits); + + ck_tile::index_t my_seqlen_k_start = + global_seqlen_k_start + num_tile_per_split * i_split * HstuAttentionPipeline::kN0; + ck_tile::index_t my_seqlen_k_end = min( + global_seqlen_k_start + num_tile_per_split * (i_split + 1) * HstuAttentionPipeline::kN0, + global_seqlen_k_end); + + return ck_tile::make_tuple(my_seqlen_k_start, my_seqlen_k_end); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + const auto [i_tile_m, i_tile_n, i_nhead, i_batch, i_split] = GetTileIndex(kargs); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_o_acc = 0; + + if constexpr(kIsJagged) + 
{ + // get starting offset for each batch + const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch]; + const long_index_t key_start = kargs.seq_kv_offsets_ptr[i_batch]; + + batch_offset_q = query_start * kargs.seq_stride_q; + batch_offset_k = key_start * kargs.seq_stride_k; + batch_offset_v = key_start * kargs.seq_stride_v; + + if constexpr(kHasBias) + { + batch_offset_bias = query_start * kargs.seq_stride_bias; + } + + // assume o_acc is in compact shape of [batch_size, max_seqlen, num_head, num_splits, + // hdim] + batch_offset_o_acc = query_start * kargs.num_head * kargs.num_splits * kargs.hdim_v; + + kargs.seqlen_q = + kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch]; + kargs.seqlen_kv = + kargs.seq_kv_offsets_ptr[i_batch + 1] - kargs.seq_kv_offsets_ptr[i_batch]; + + // read from device memory for the group specific mask and scaling parameters + if constexpr(kUseGroup) + { + index_t i_group = + __builtin_amdgcn_readfirstlane(i_batch / kargs.num_batch_per_group); + + float attn_scale = kargs.group_attn_scale_ptr[i_group]; + index_t max_seqlen = kargs.group_max_seqlen_ptr[i_group]; + kargs.scale_p = (attn_scale ? 
attn_scale : 1.0f / static_cast(max_seqlen)); + kargs.contextual_seqlen = kargs.group_contextual_seqlen_ptr[i_group]; + kargs.window_size = kargs.group_window_size_ptr[i_group]; + kargs.min_full_attn_seqlen = kargs.group_min_full_attn_seqlen_ptr[i_group]; + }; + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kHasBias) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + // assume o_acc is in compact shape of [batch_size, seqlen_q, num_head, num_splits, + // hdim] + batch_offset_o_acc = static_cast(i_batch) * kargs.seqlen_q * + kargs.num_head * kargs.num_splits * kargs.hdim_v; + } + + int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch]; + + index_t seqlen_in_first_split = kargs.seqlen_q; + bool is_tile_in_first_split = true; + index_t i_m0; + + if(kargs.min_full_attn_seqlen > 0) + { + // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len + if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen) + { + seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen; + + index_t num_tile_in_first_split = + __builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil( + seqlen_in_first_split, HstuAttentionPipeline::kM0)); + + is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + + i_m0 = is_tile_in_first_split + ? 
__builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) + : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * + HstuAttentionPipeline::kM0) + + seqlen_in_first_split; + } + else + { + seqlen_in_first_split = 0; + is_tile_in_first_split = false; + + // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor + kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target; + + i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); + }; + } + else + i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); + + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1); + + index_t seqlen_q_in_ctrl = is_tile_in_first_split ? seqlen_in_first_split : kargs.seqlen_q; + + if(seqlen_q_in_ctrl <= i_m0) + return; + + // for simplicity, batch stride we just modify the pointer + const QKVDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const QKVDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_k + + batch_offset_k; + const QKVDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_v + + batch_offset_v; + // assume o_acc is in compact shape of [batch_size, seqlen, num_head, num_splits, hdim] + OaccDataType* o_acc_ptr = + reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.num_splits * kargs.hdim_v + + static_cast(i_split) * kargs.hdim_v + batch_offset_o_acc; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(seqlen_q_in_ctrl, kargs.hdim_qk), + make_tuple(kargs.seq_stride_q, 1), + number{}, + number<1>{}); + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + 
make_tuple(kargs.seqlen_kv, kargs.hdim_qk), + make_tuple(kargs.seq_stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view(k_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_kv, kargs.hdim_v), + make_tuple(kargs.seq_stride_v, 1), + number{}, + number<1>{}); + + if constexpr(!kUseTrLoad) + { + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_kv)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view(v_dram_transposed, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(v_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }; + }(); + + auto q_dram_window = + make_tile_window(q_dram, + [&]() { + return make_tuple(number{}, + number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = + make_tile_window(k_dram, + make_tuple(number{}, + number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. 
Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = make_tuple( + number{}, number{}); + if constexpr(kHasBias) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv), + make_tuple(kargs.seq_stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + bias_dram_naive, bias_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head, + kargs.drop_seed, + kargs.drop_offset, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + false}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto o_acc_tile = [&]() { + if(kargs.window_size > 0) + { + using HstuMaskType = typename ck_tile:: + HstuBlockMasking::Type; + + auto mask = [&]() { + if constexpr(kIsCrossAttention) + { + return make_hstu_cross_attention_block_mask_with_local( + is_tile_in_first_split, + kargs.seqlen_q, + kargs.seqlen_kv, + kargs.contextual_seqlen, + num_target, + kargs.window_size, + kargs.min_full_attn_seqlen); + } + else + { + return make_hstu_self_attention_block_mask_with_local( + is_tile_in_first_split, + kargs.seqlen_q, + kargs.contextual_seqlen, + num_target, + kargs.window_size, + kargs.min_full_attn_seqlen); + }; + }(); + + const auto [global_seqlen_k_start, global_seqlen_k_end] = + mask.GetTileRangeAlongX(i_m0, + number{}, + number{}); + + const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit( + 
global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split); + + return HstuAttentionPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + seqlen_k_start, + seqlen_k_end, + mask, + kargs.scale_s, + kargs.scale_p, + smem_ptr, + dropout); + } + else + { + using HstuMaskType = typename ck_tile:: + HstuBlockMasking::Type; + + auto mask = [&]() { + if constexpr(kIsCrossAttention) + { + return make_hstu_cross_attention_block_mask_without_local( + kargs.seqlen_q, kargs.seqlen_kv, kargs.contextual_seqlen, num_target); + } + else + { + return make_hstu_self_attention_block_mask_without_local( + kargs.seqlen_q, kargs.contextual_seqlen, num_target); + }; + }(); + + const auto [global_seqlen_k_start, global_seqlen_k_end] = + mask.GetTileRangeAlongX(i_m0, + number{}, + number{}); + + const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit( + global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split); + + return HstuAttentionPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + seqlen_k_start, + seqlen_k_end, + mask, + kargs.scale_s, + kargs.scale_p, + smem_ptr, + dropout); + } + }(); + + // Oacc DRAM and Oacc DRAM window + auto o_acc_dram = [&]() { + auto seq_stride_o_acc = kargs.num_head * kargs.num_splits * kargs.hdim_v; + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(seqlen_q_in_ctrl, kargs.hdim_v), + make_tuple(seq_stride_o_acc, 1), + number{}, + number<1>{}); + + return pad_tensor_view(o_acc_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + auto o_acc_dram_window = make_tile_window( + o_acc_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp index 
58886d05b2..987e3cf198 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp @@ -22,6 +22,8 @@ #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" +#include "hstu_attention_group_forward_splitkv_dispatch.hpp" + template ::Run(param, stream); else - group_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + { + const bool disable_fwd_splitkv = []() { + const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV"); + if(env_p == nullptr) + return false; + return static_cast(atoi(env_p)); + }(); + + // ToDo: enable splitkv when kUseSoftmax is true + if(!disable_fwd_splitkv && !kUseSoftmax && + shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen)) + { + if constexpr(!kUseSoftmax) + { + group_forward_splitkv_causal_softmax_bias_dropout_dispatch::Run(param, stream); + }; + } + else + group_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + }; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp new file mode 100644 index 0000000000..b30e15dd4e --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_fwd_setting.hpp" +#include "hstu_attention_fwd_splitkv_combine_setting.hpp" +#include "hstu_attention_params.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_pipeline_problem.hpp" +#include "hstu_attention_traits.hpp" +#include "hstu_attention_with_softmax_fwd_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp" +#include "hstu_attention_fwd_splitkv_kernel.hpp" +#include "hstu_attention_fwd_splitkv_combine_kernel.hpp" + +template +struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch +{ + static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!"); + static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!"); + + using HstuAttentionFwdTileSetting = + typename std::conditional_t, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; + using HstuAttentionCombineTileSetting = + typename HstuAttentionFwdSplitKVCombineTileSetting::Type; + +#ifdef BUILD_HSTU_FOR_GFX95_ONLY + static constexpr bool kUseTrLoad = true; +#else + static constexpr bool kUseTrLoad = false; +#endif + + template + using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< + InOutDataType, + typename HstuAttentionFwdTypeConfig::GemmAccDataType, + typename HstuAttentionFwdTypeConfig::CompDataType, + typename HstuAttentionFwdTypeConfig::BiasDataType, + kIsCrossAttention, + true, // kUseGroup + true, // kIsJagged + kHasBias, + kHasDropout, + kUseCausal, + kUseSoftmax, + kUseTrLoad, + HstuAttentionFwdTileSetting>; + + using OaccDataType = HstuAttentionFwdTypeConfig::OaccDataType; + using ODataType = 
HstuAttentionFwdTypeConfig::ODataType; + + static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream) + { + constexpr ck_tile::index_t occupancy = -1; + + { + const bool pad_headdim_qk = + !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + constexpr bool kPadSeqLenK = true; + + BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; + + using HstuEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + + BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { + using HstuPipelineProblem = HstuFwdPipelineProblemTemp; + + if constexpr(!kUseTrLoad) + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + else + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad< + HstuPipelineProblem, + HstuTraits>, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad< + HstuPipelineProblem, + HstuTraits>>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + }; + }); + }); + }; + { + using HstuCombinePipelineProblem = + ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem< + OaccDataType, + ODataType, + true /* kIsJagged */, + kUseSoftmax, + HstuAttentionCombineTileSetting>; + const bool pad_headdim_o = + !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0); + + // no need to check seqlen_q since it is not used as fastest dim, 
+ // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] { + using HstuTraits = ck_tile:: + HstuAttentionFwdSplitKVCombineTraits; + + using HstuEpilogue = + ck_tile::Default2DEpilogue>; + + using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline< + HstuCombinePipelineProblem, + HstuTraits>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVCombineKernel; + + RunWithFwdSplitKVCombineKernel(param, stream); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream) + { + param.num_splits = + get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen); + + // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head, + // num_splits, hdim] + size_t workspace_bytes = static_cast(param.num_batch) * param.max_seqlen * + param.num_head * param.num_splits * param.hdim_v * + sizeof(OaccDataType); + + HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream)); + + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.q_ptr, + param.k_ptr, + param.v_ptr, + param.bias_ptr, + param.o_acc_ptr, + param.num_splits, + param.num_batch / param.num_group, + param.seq_q_offsets_ptr, + param.is_cross_attention ? 
param.seq_kv_offsets_ptr + : param.seq_q_offsets_ptr, + param.group_max_seqlen_ptr, + param.group_contextual_seqlen_ptr, + param.group_window_size_ptr, + param.group_min_full_attn_seqlen_ptr, + param.group_attn_scale_ptr, + param.hdim_qk, + param.hdim_v, + param.num_head, + param.scale_s, + param.seq_stride_q, + param.seq_stride_k, + param.seq_stride_v, + param.seq_stride_bias, + param.nhead_stride_q, + param.nhead_stride_k, + param.nhead_stride_v, + param.nhead_stride_bias, + param.num_targets_ptr, + param.p_drop, + param.philox_seed, + param.philox_offset); + }(); + + dim3 kGridSize = HstuKernel::GridSize( + param.num_batch, param.num_head, param.max_seqlen, param.hdim_v, param.num_splits); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithFwdSplitKVCombineKernel(HstuAttentionGroupFwdParams& param, + hipStream_t stream) + { + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.o_acc_ptr, + param.o_ptr, + param.seq_stride_o, + param.nhead_stride_o, + param.seq_q_offsets_ptr, + param.num_head, + param.num_splits, + param.hdim_v); + }(); + + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + + HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream)); + }; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index bf7dda05e0..717b414f82 100644 --- 
a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -22,6 +22,8 @@ #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" +#include "hstu_attention_jagged_forward_splitkv_dispatch.hpp" + template ::Run(param, stream); else - jagged_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + { + const bool disable_fwd_splitkv = []() { + const char* env_p = std::getenv("HSTU_DISABLE_SPLITKV"); + if(env_p == nullptr) + return false; + return static_cast(atoi(env_p)); + }(); + + // ToDo: enable splitkv when kUseSoftmax is true + if(!disable_fwd_splitkv && !kUseSoftmax && + shall_use_splitkv(param.num_batch, param.num_head, param.max_seqlen)) + { + if constexpr(!kUseSoftmax) + { + jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch::Run(param, stream); + }; + } + else + jagged_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + }; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp new file mode 100644 index 0000000000..1f88c42e38 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_fwd_setting.hpp" +#include "hstu_attention_fwd_splitkv_combine_setting.hpp" +#include "hstu_attention_params.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_pipeline_problem.hpp" +#include "hstu_attention_traits.hpp" +#include "hstu_attention_with_softmax_fwd_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_pipeline.hpp" +#include "hstu_attention_with_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_trload_pipeline.hpp" +#include "hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp" +#include "hstu_attention_fwd_splitkv_kernel.hpp" +#include "hstu_attention_fwd_splitkv_combine_kernel.hpp" + +template +struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch +{ + static_assert(kUseSoftmax == false, "Softmax support is not enabled yet!"); + static_assert(MTile == 64, "MTile must be 64 to get to fwd splitkv path!"); + + using HstuAttentionFwdTileSetting = + typename std::conditional_t, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; + using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting::Type; + +#ifdef BUILD_HSTU_FOR_GFX95_ONLY + static constexpr bool kUseTrLoad = true; +#else + static constexpr bool kUseTrLoad = false; +#endif + + template + using HstuFwdPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< + InOutDataType, + typename HstuAttentionFwdTypeConfig::GemmAccDataType, + typename HstuAttentionFwdTypeConfig::CompDataType, + typename HstuAttentionFwdTypeConfig::BiasDataType, + kIsCrossAttention, + false, // kUseGroup + true, // kIsJagged + kHasBias, + kHasDropout, + kUseCausal, + kUseSoftmax, + kUseTrLoad, + HstuAttentionFwdTileSetting>; + + using OaccDataType = HstuAttentionFwdTypeConfig::OaccDataType; + using ODataType = HstuAttentionFwdTypeConfig::ODataType; + + 
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) + { + constexpr ck_tile::index_t occupancy = -1; + + { + const bool pad_headdim_qk = + !(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionFwdTileSetting::kN1 == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + constexpr bool kPadSeqLenK = true; + + BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; + + using HstuEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + + BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { + using HstuPipelineProblem = HstuFwdPipelineProblemTemp; + + if constexpr(!kUseTrLoad) + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + } + else + { + using HstuPipeline = std::conditional_t< + kUseSoftmax, + ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad< + HstuPipelineProblem, + HstuTraits>, + ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad< + HstuPipelineProblem, + HstuTraits>>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVKernel; + + RunWithFwdSplitKVKernel(param, stream); + }; + }); + }); + }; + { + using HstuCombinePipelineProblem = + ck_tile::HstuAttentionFwdSplitKVCombinePipelineProblem< + OaccDataType, + ODataType, + true /* kIsJagged */, + kUseSoftmax, + HstuAttentionCombineTileSetting>; + const bool pad_headdim_o = + !(param.hdim_v % HstuAttentionCombineTileSetting::kOHeaddim == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // 
buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + BOOL_SWITCH(pad_headdim_o, kPadHeadDimO, [&] { + using HstuTraits = ck_tile:: + HstuAttentionFwdSplitKVCombineTraits; + + using HstuEpilogue = + ck_tile::Default2DEpilogue>; + + using HstuPipeline = ck_tile::HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline< + HstuCombinePipelineProblem, + HstuTraits>; + + using HstuKernel = + ck_tile::HstuAttentionFwdSplitKVCombineKernel; + + RunWithFwdSplitKVCombineKernel(param, stream); + }); + }; + }; + + template + static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) + { + param.num_splits = + get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen); + + // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head, + // num_splits, hdim] + size_t workspace_bytes = static_cast(param.num_batch) * param.max_seqlen * + param.num_head * param.num_splits * param.hdim_v * + sizeof(OaccDataType); + + HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream)); + + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.q_ptr, + param.k_ptr, + param.v_ptr, + param.bias_ptr, + param.o_acc_ptr, + param.num_splits, + param.seq_q_offsets_ptr, + param.is_cross_attention ? 
param.seq_kv_offsets_ptr + : param.seq_q_offsets_ptr, + param.max_seqlen, + param.hdim_qk, + param.hdim_v, + param.num_head, + param.scale_s, + param.attn_scale, + param.seq_stride_q, + param.seq_stride_k, + param.seq_stride_v, + param.seq_stride_bias, + param.nhead_stride_q, + param.nhead_stride_k, + param.nhead_stride_v, + param.nhead_stride_bias, + param.num_targets_ptr, + param.contextual_seqlen, + param.window_size, + param.min_full_attn_seqlen, + param.p_drop, + param.philox_seed, + param.philox_offset); + }(); + + bool has_minfull_attn_seqlen = (param.min_full_attn_seqlen > 0); + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, + param.num_head, + param.max_seqlen, + param.hdim_v, + param.num_splits, + has_minfull_attn_seqlen); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; + + template + static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param, + hipStream_t stream) + { + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.o_acc_ptr, + param.o_ptr, + param.seq_stride_o, + param.nhead_stride_o, + param.seq_q_offsets_ptr, + param.num_head, + param.num_splits, + param.hdim_v); + }(); + + dim3 kGridSize = HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{stream, false}, + ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + + HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream)); + }; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 
fd984a207d..c73f0ecffd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -180,6 +180,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; + if(seqlen_k_end <= seqlen_k_start) + { + clear_tile(o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + }; + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp new file mode 100644 index 0000000000..803fc6949f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp" + +namespace ck_tile { + +template +struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline +{ + using Problem = remove_cvref_t; + using Traits = remove_cvref_t; + using Policy = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM = Problem::kM; + static constexpr index_t kOHeaddim = Problem::kOHeaddim; + + static_assert(kOHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax"); + + static constexpr bool kIsJagged = Problem::kIsJagged; + + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimO = Traits::kPadHeadDimO; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentO = + kPadHeadDimO ? 
1 : Policy::template GetAlignmentO(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Traits::kBlockPerCu != -1) + return Traits::kBlockPerCu; + else + { + return 2; + } + }(); + + static constexpr const char* name = "hstu_no_softmax_fwd_splitkv_combine"; + + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE auto + operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // M0*kOHeaddim tile + const OAccElementFunction& o_acc_element_func, + ck_tile::index_t split_stride, + ck_tile::index_t num_splits) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kM == OAccDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kOHeaddim == OAccDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + auto o_acc_dram_window = + make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + o_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakeOaccDramTileDistribution()); + + auto o_acc_ptr = o_acc_dram_window.get_bottom_tensor_view().get_buffer_view().p_data_; + + auto o_acc = load_tile(o_acc_dram_window); + + for(int i = 1; i < num_splits; i++) + { + o_acc_dram_window.set_bottom_tensor_view_data_ptr(o_acc_ptr + split_stride * i); + auto o_acc_tile = load_tile(o_acc_dram_window); + + tile_elementwise_inout([](auto& x, const auto& y) { x = x + y; }, o_acc, o_acc_tile); + }; + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_DEVICE auto + operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile + ck_tile::index_t split_stride, + ck_tile::index_t num_splits) const + { + return operator()(o_acc_dram_block_window_tmp, identity{}, split_stride, num_splits); + }; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp 
b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 8d2bac5579..d2fc5806c8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -180,6 +180,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; + if(seqlen_k_end <= seqlen_k_start) + { + clear_tile(o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + }; + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 955d8c688c..03b8fc27e3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -65,6 +65,13 @@ struct HstuAttentionNoGroupFwdParams float p_drop; uint64_t philox_seed; uint64_t philox_offset; + + // this need not be set by the API users, we only use it for passing num_splits between splitkv + // and combine kernel + int num_splits; + // pointer of device memory allocated before calling fwd_splitkv kernel and released after + // calling combine kernel + void* o_acc_ptr; }; struct HstuAttentionGroupFwdParams @@ -125,4 +132,11 @@ struct HstuAttentionGroupFwdParams float p_drop; uint64_t philox_seed; uint64_t philox_offset; + + // this need not be set by the API users, we only use it for passing num_splits between splitkv + // and combine kernel + int num_splits; + // pointer of device memory allocated before calling fwd_splitkv kernel and released after + // calling combine kernel + void* o_acc_ptr; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp 
b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 40fe064639..3dabf8acd3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -133,4 +133,42 @@ struct HstuAttentionFwdPipelineProblem }; }; +template +struct HstuAttentionFwdSplitKVCombinePipelineProblem +{ + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr bool kIsJagged = kIsJagged_; + static constexpr bool kUseSoftmax = kUseSoftmax_; + + static constexpr index_t kM = CombineTileSetting_::kM; + static constexpr index_t NumWarps = CombineTileSetting_::NumWarps; + static constexpr index_t kOHeaddim = CombineTileSetting_::kOHeaddim; + static constexpr index_t kSubOHeaddim = CombineTileSetting_::kSubOHeaddim; + static constexpr index_t kBlockSize = CombineTileSetting_::NumWarps * get_warp_size(); + + CK_TILE_HOST_DEVICE static constexpr auto GetOaccDramTileAccessMaxVectorSize() + { + constexpr index_t kMPerBlock = kM; + constexpr index_t kKPerBlock = kOHeaddim; + + return detail:: + GetDramTileAccessMaxVectorSize(); + }; + + CK_TILE_HOST_DEVICE static constexpr auto GetODramTileAccessMaxVectorSize() + { + constexpr index_t kMPerBlock = kM; + constexpr index_t kKPerBlock = kOHeaddim; + + return detail:: + GetDramTileAccessMaxVectorSize(); + }; +}; + } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp index efb9134edd..f3bacff35e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp @@ -64,4 +64,21 @@ struct HstuAttentionFwdTileSettingClass static_assert(kSubQKHeaddim % kN1 == 0, "Check failed!"); }; +template +struct HstuAttentionFwdSplitKVCombineTileSettingClass +{ + static constexpr index_t kM = kM_; 
+ static constexpr index_t NumWarps = NumWarps_; + + static_assert(kM % NumWarps == 0, "Check failed!"); + + static constexpr index_t kOHeaddim = kOHeaddim_; + + static_assert((kM * kOHeaddim) % (NumWarps * get_warp_size()) == 0, "Check failed!"); + + static constexpr index_t kSubOHeaddim = ceil_to_qualified_tile_length(kOHeaddim); +}; + } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp index 5762e501b3..66aab00e6d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp @@ -22,4 +22,15 @@ struct HstuAttentionFwdTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct HstuAttentionFwdSplitKVCombineTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimO = kPadHeadDimO_; + + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile