From 5ee8a37cd3dab0ff91fae9b99e8e4ed30449c449 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 2 Jun 2026 15:43:44 +0000 Subject: [PATCH] Move num_splits/o_acc_ptr/l_acc_ptr out from HstuAttentionFwdParams struct --- ...ntion_batched_forward_splitkv_dispatch.hpp | 51 +++++++++++-------- .../hstu_attention_fwd_setting.hpp | 34 ------------- ...tention_group_forward_splitkv_dispatch.hpp | 50 ++++++++++-------- ...ention_jagged_forward_splitkv_dispatch.hpp | 50 ++++++++++-------- .../hstu_attention_params.hpp | 16 ------ .../hstu_attention_splitkv_helper.hpp | 47 +++++++++++++++++ 6 files changed, 134 insertions(+), 114 deletions(-) create mode 100644 example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp index b5de76adb0..4d493bf028 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp @@ -26,6 +26,7 @@ #include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp" #include "hstu_attention_fwd_splitkv_kernel.hpp" #include "hstu_attention_fwd_splitkv_combine_kernel.hpp" +#include "hstu_attention_splitkv_helper.hpp" template ; - RunWithFwdSplitKVKernel(param, stream); + RunWithFwdSplitKVKernel(param, ws, stream); } else { @@ -150,7 +153,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVKernel; - RunWithFwdSplitKVKernel(param, stream); + RunWithFwdSplitKVKernel(param, ws, stream); }; }); }); @@ -170,7 +173,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access constexpr bool kPadSeqLenQ = false; - MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] { + MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] { constexpr ck_tile::index_t kMaxSplits = [&]() { if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize) return TMP_MAX_SPLITS; @@ -180,7 +183,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch return 4 * TMP_MAX_SPLITS; }(); - const bool pad_num_splits = (param.num_splits < kMaxSplits); + const bool pad_num_splits = (ws.num_splits < kMaxSplits); using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp; @@ -204,7 +207,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVCombineKernel; - RunWithFwdSplitKVCombineKernel(param, stream); + RunWithFwdSplitKVCombineKernel(param, ws, stream); }); }); } @@ -240,31 +243,34 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVCombineKernel; - RunWithFwdSplitKVCombineKernel(param, stream); + RunWithFwdSplitKVCombineKernel(param, ws, stream); }); } }; template - static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) + static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, + SplitkvWorkspace& ws, + hipStream_t stream) { - param.num_splits = - get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q); + ws.num_splits = get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q); // assume the workspace for o_acc is in compact shape of [num_batch, seqlen_q, num_head, // num_splits, hdim] size_t workspace_bytes = static_cast(param.num_batch) * param.seqlen_q * - param.num_head * param.num_splits * param.hdim_v * + param.num_head * ws.num_splits * param.hdim_v * sizeof(OaccDataType); - HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream)); + HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream)); if constexpr(kUseSoftmax) { + // assume the workspace for l_acc is in compact shape of [num_batch, seqlen_q, + // num_head, num_splits] workspace_bytes = static_cast(param.num_batch) * param.seqlen_q * - param.num_head * param.num_splits * sizeof(LSEDataType); + param.num_head * ws.num_splits * sizeof(LSEDataType); - HIP_CHECK_ERROR(hipMallocAsync(¶m.lse_acc_ptr, workspace_bytes, stream)); + HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream)); } const auto kargs = [&] { @@ -272,9 +278,9 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch param.k_ptr, param.v_ptr, param.bias_ptr, - param.o_acc_ptr, - param.lse_acc_ptr, - param.num_splits, + ws.o_acc_ptr, + ws.lse_acc_ptr, + ws.num_splits, param.seqlen_q, param.is_cross_attention ? param.seqlen_kv : param.seqlen_q, @@ -309,7 +315,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch param.num_head, param.seqlen_q, param.hdim_v, - param.num_splits, + ws.num_splits, has_minfull_attn_seqlen); dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; @@ -321,18 +327,19 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch template static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param, + SplitkvWorkspace& ws, hipStream_t stream) { const auto kargs = [&] { - return HstuKernel::MakeKargs(param.o_acc_ptr, - param.lse_acc_ptr, + return HstuKernel::MakeKargs(ws.o_acc_ptr, + ws.lse_acc_ptr, param.o_ptr, param.batch_stride_o, param.seq_stride_o, param.nhead_stride_o, param.seqlen_q, param.num_head, - param.num_splits, + ws.num_splits, param.hdim_v); }(); @@ -344,10 +351,10 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch ck_tile::stream_config{stream, false}, ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); - HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream)); + HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream)); if constexpr(kUseSoftmax) { - HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream)); + HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream)); } }; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 033c22cc8f..98c912fc5f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -484,37 +484,3 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_ // mtile-64 can be added through tuning/verification return 64; }; - -static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q) -{ - int num_CUs = get_number_of_cu(); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - - int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64); - - // assume each CU can run two work-groups, common cases for hdim128 - return static_cast(nbatch_nhead_mblocks) / (2.0f * num_CUs); -}; - -static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q) -{ - // Please tune the threshold here - const float threshold = 0.8f; - - if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold) - return true; - return false; -}; - -static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q) -{ - int i = 2; - - // Please tune the threshold here - const float threshold = 1.5f; - while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold) - i++; - - // the num_splits shall not be bigger than 64 - return ck_tile::min(i, 64); -}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp index b5ad2e59ae..35eda44a0d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp @@ -26,6 +26,7 @@ #include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp" #include "hstu_attention_fwd_splitkv_kernel.hpp" #include "hstu_attention_fwd_splitkv_combine_kernel.hpp" +#include "hstu_attention_splitkv_helper.hpp" template ; - RunWithFwdSplitKVKernel(param, stream); + RunWithFwdSplitKVKernel(param, ws, stream); } else { @@ -142,7 +145,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVKernel; - RunWithFwdSplitKVKernel(param, stream); + RunWithFwdSplitKVKernel(param, ws, stream); }; }); }); @@ -162,7 +165,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access constexpr bool kPadSeqLenQ = false; - MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] { + MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] { constexpr ck_tile::index_t kMaxSplits = [&]() { if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize) return TMP_MAX_SPLITS; @@ -172,7 +175,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch return 4 * TMP_MAX_SPLITS; }(); - const bool pad_num_splits = (param.num_splits < kMaxSplits); + const bool pad_num_splits = (ws.num_splits < kMaxSplits); using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp; @@ -196,7 +199,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVCombineKernel; - RunWithFwdSplitKVCombineKernel(param, stream); + RunWithFwdSplitKVCombineKernel(param, ws, stream); }); }); } @@ -232,31 +235,35 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVCombineKernel; - RunWithFwdSplitKVCombineKernel(param, stream); + RunWithFwdSplitKVCombineKernel(param, ws, stream); }); } }; template - static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream) + static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, + SplitkvWorkspace& ws, + hipStream_t stream) { - param.num_splits = + ws.num_splits = get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q); // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head, // num_splits, hdim] size_t workspace_bytes = static_cast(param.num_batch) * param.max_seqlen_q * - param.num_head * param.num_splits * param.hdim_v * + param.num_head * ws.num_splits * param.hdim_v * sizeof(OaccDataType); - HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream)); + HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream)); if constexpr(kUseSoftmax) { + // assume the workspace for l_acc is in compact shape of [num_batch, max_seqlen, + // num_head, num_splits] workspace_bytes = static_cast(param.num_batch) * param.max_seqlen_q * - param.num_head * param.num_splits * sizeof(LSEDataType); + param.num_head * ws.num_splits * sizeof(LSEDataType); - HIP_CHECK_ERROR(hipMallocAsync(¶m.lse_acc_ptr, workspace_bytes, stream)); + HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream)); } const auto kargs = [&] { @@ -264,9 +271,9 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch param.k_ptr, param.v_ptr, param.bias_ptr, - param.o_acc_ptr, - param.lse_acc_ptr, - param.num_splits, + ws.o_acc_ptr, + ws.lse_acc_ptr, + ws.num_splits, param.num_batch / param.num_group, param.seq_q_offsets_ptr, param.is_cross_attention ? param.seq_kv_offsets_ptr @@ -295,7 +302,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch }(); dim3 kGridSize = HstuKernel::GridSize( - param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits); + param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, ws.num_splits); dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; @@ -306,17 +313,18 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch template static void RunWithFwdSplitKVCombineKernel(HstuAttentionGroupFwdParams& param, + SplitkvWorkspace& ws, hipStream_t stream) { const auto kargs = [&] { - return HstuKernel::MakeKargs(param.o_acc_ptr, - param.lse_acc_ptr, + return HstuKernel::MakeKargs(ws.o_acc_ptr, + ws.lse_acc_ptr, param.o_ptr, param.seq_stride_o, param.nhead_stride_o, param.seq_q_offsets_ptr, param.num_head, - param.num_splits, + ws.num_splits, param.hdim_v); }(); @@ -328,10 +336,10 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch ck_tile::stream_config{stream, false}, ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); - HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream)); + HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream)); if constexpr(kUseSoftmax) { - HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream)); + HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream)); } }; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp index 570b99dfca..57e5338a55 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp @@ -26,6 +26,7 @@ #include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp" #include "hstu_attention_fwd_splitkv_kernel.hpp" #include "hstu_attention_fwd_splitkv_combine_kernel.hpp" +#include "hstu_attention_splitkv_helper.hpp" template ; - RunWithFwdSplitKVKernel(param, stream); + RunWithFwdSplitKVKernel(param, ws, stream); } else { @@ -141,7 +144,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVKernel; - RunWithFwdSplitKVKernel(param, stream); + RunWithFwdSplitKVKernel(param, ws, stream); }; }); }); @@ -161,7 +164,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access constexpr bool kPadSeqLenQ = false; - MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] { + MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] { constexpr ck_tile::index_t kMaxSplits = [&]() { if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize) return TMP_MAX_SPLITS; @@ -171,7 +174,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch return 4 * TMP_MAX_SPLITS; }(); - const bool pad_num_splits = (param.num_splits < kMaxSplits); + const bool pad_num_splits = (ws.num_splits < kMaxSplits); using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp; @@ -195,7 +198,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVCombineKernel; - RunWithFwdSplitKVCombineKernel(param, stream); + RunWithFwdSplitKVCombineKernel(param, ws, stream); }); }); } @@ -231,31 +234,35 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuKernel = ck_tile::HstuAttentionFwdSplitKVCombineKernel; - RunWithFwdSplitKVCombineKernel(param, stream); + RunWithFwdSplitKVCombineKernel(param, ws, stream); }); }; }; template - static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) + static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, + SplitkvWorkspace& ws, + hipStream_t stream) { - param.num_splits = + ws.num_splits = get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q); // assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head, // num_splits, hdim] size_t workspace_bytes = static_cast(param.num_batch) * param.max_seqlen_q * - param.num_head * param.num_splits * param.hdim_v * + param.num_head * ws.num_splits * param.hdim_v * sizeof(OaccDataType); - HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream)); + HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream)); if constexpr(kUseSoftmax) { + // assume the workspace for l_acc is in compact shape of [num_batch, max_seqlen, + // num_head, num_splits] workspace_bytes = static_cast(param.num_batch) * param.max_seqlen_q * - param.num_head * param.num_splits * sizeof(LSEDataType); + param.num_head * ws.num_splits * sizeof(LSEDataType); - HIP_CHECK_ERROR(hipMallocAsync(¶m.lse_acc_ptr, workspace_bytes, stream)); + HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream)); } const auto kargs = [&] { @@ -263,9 +270,9 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch param.k_ptr, param.v_ptr, param.bias_ptr, - param.o_acc_ptr, - param.lse_acc_ptr, - param.num_splits, + ws.o_acc_ptr, + ws.lse_acc_ptr, + ws.num_splits, param.seq_q_offsets_ptr, param.is_cross_attention ? param.seq_kv_offsets_ptr : param.seq_q_offsets_ptr, @@ -297,7 +304,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch param.num_head, param.max_seqlen_q, param.hdim_v, - param.num_splits, + ws.num_splits, has_minfull_attn_seqlen); dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; @@ -309,17 +316,18 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch template static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param, + SplitkvWorkspace& ws, hipStream_t stream) { const auto kargs = [&] { - return HstuKernel::MakeKargs(param.o_acc_ptr, - param.lse_acc_ptr, + return HstuKernel::MakeKargs(ws.o_acc_ptr, + ws.lse_acc_ptr, param.o_ptr, param.seq_stride_o, param.nhead_stride_o, param.seq_q_offsets_ptr, param.num_head, - param.num_splits, + ws.num_splits, param.hdim_v); }(); @@ -331,10 +339,10 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch ck_tile::stream_config{stream, false}, ck_tile::make_kernel(HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); - HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream)); + HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream)); if constexpr(kUseSoftmax) { - HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream)); + HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream)); } }; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 592324b7d9..cd43e74ce5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -65,14 +65,6 @@ struct HstuAttentionNoGroupFwdParams float p_drop; uint64_t philox_seed; uint64_t philox_offset; - - // this need not be set by the API users, we only use it for passing num_splits between splitkv - // and combine kernel - int num_splits; - // pointer of device memory allocated before calling fwd_splitkv kernel and released after - // calling combine kernel - void* o_acc_ptr; - void* lse_acc_ptr; }; struct HstuAttentionGroupFwdParams @@ -133,12 +125,4 @@ struct HstuAttentionGroupFwdParams float p_drop; uint64_t philox_seed; uint64_t philox_offset; - - // this need not be set by the API users, it only use it for passing num_splits between splitkv - // and combine kernel - int num_splits; - // pointer of device memory allocated before calling fwd_splitkv kernel and released after - // calling combine kernel - void* o_acc_ptr; - void* lse_acc_ptr; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp new file mode 100644 index 0000000000..728dcb3095 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_splitkv_helper.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "hstu_attention_util.hpp" + +static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q) +{ + int num_CUs = get_number_of_cu(); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + + int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64); + + // assume each CU can run two work-groups, common cases for hdim128 + return static_cast(nbatch_nhead_mblocks) / (2.0f * num_CUs); +}; + +static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q) +{ + // Please tune the threshold here + const float threshold = 0.8f; + + if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold) + return true; + return false; +}; + +static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q) +{ + int i = 2; + + // Please tune the threshold here + const float threshold = 1.5f; + while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold) + i++; + + // the num_splits shall not be bigger than 64 + return ck_tile::min(i, 64); +}; + +struct SplitkvWorkspace +{ + int num_splits; + void* o_acc_ptr; + void* lse_acc_ptr; // only used when softmax is used +};