Move num_splits/o_acc_ptr/l_acc_ptr out from HstuAttention<xxx>FwdParams struct

This commit is contained in:
Qianfeng Zhang
2026-06-02 15:43:44 +00:00
parent 36dd77fb16
commit 5ee8a37cd3
6 changed files with 134 additions and 114 deletions

View File

@@ -26,6 +26,7 @@
#include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
#include "hstu_attention_splitkv_helper.hpp"
template <typename InOutDataType,
bool kUseCausal,
@@ -85,6 +86,8 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
constexpr ck_tile::index_t occupancy = -1;
SplitkvWorkspace ws;
{
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionFwdTileSetting::kN0 == 0);
const bool pad_headdim_qk =
@@ -134,7 +137,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
}
else
{
@@ -150,7 +153,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
};
});
});
@@ -170,7 +173,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] {
MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] {
constexpr ck_tile::index_t kMaxSplits = [&]() {
if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize)
return TMP_MAX_SPLITS;
@@ -180,7 +183,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
return 4 * TMP_MAX_SPLITS;
}();
const bool pad_num_splits = (param.num_splits < kMaxSplits);
const bool pad_num_splits = (ws.num_splits < kMaxSplits);
using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp<kMaxSplits>;
@@ -204,7 +207,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
});
});
}
@@ -240,31 +243,34 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
});
}
};
template <typename HstuKernel>
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param,
SplitkvWorkspace& ws,
hipStream_t stream)
{
param.num_splits =
get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q);
ws.num_splits = get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q);
// assume the workspace for o_acc is in compact shape of [num_batch, seqlen_q, num_head,
// num_splits, hdim]
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.seqlen_q *
param.num_head * param.num_splits * param.hdim_v *
param.num_head * ws.num_splits * param.hdim_v *
sizeof(OaccDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream));
if constexpr(kUseSoftmax)
{
// assume the workspace for l_acc is in compact shape of [num_batch, seqlen_q,
// num_head, num_splits]
workspace_bytes = static_cast<size_t>(param.num_batch) * param.seqlen_q *
param.num_head * param.num_splits * sizeof(LSEDataType);
param.num_head * ws.num_splits * sizeof(LSEDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.lse_acc_ptr, workspace_bytes, stream));
HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream));
}
const auto kargs = [&] {
@@ -272,9 +278,9 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_acc_ptr,
param.lse_acc_ptr,
param.num_splits,
ws.o_acc_ptr,
ws.lse_acc_ptr,
ws.num_splits,
param.seqlen_q,
param.is_cross_attention ? param.seqlen_kv
: param.seqlen_q,
@@ -309,7 +315,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.num_head,
param.seqlen_q,
param.hdim_v,
param.num_splits,
ws.num_splits,
has_minfull_attn_seqlen);
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
@@ -321,18 +327,19 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
template <typename HstuKernel>
static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
SplitkvWorkspace& ws,
hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.o_acc_ptr,
param.lse_acc_ptr,
return HstuKernel::MakeKargs(ws.o_acc_ptr,
ws.lse_acc_ptr,
param.o_ptr,
param.batch_stride_o,
param.seq_stride_o,
param.nhead_stride_o,
param.seqlen_q,
param.num_head,
param.num_splits,
ws.num_splits,
param.hdim_v);
}();
@@ -344,10 +351,10 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream));
if constexpr(kUseSoftmax)
{
HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream));
HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream));
}
};
};

View File

@@ -484,37 +484,3 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_
// mtile-64 can be added through tuning/verification
return 64;
};
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
{
int num_CUs = get_number_of_cu();
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64);
// assume each CU can run two work-groups, common cases for hdim128
return static_cast<float>(nbatch_nhead_mblocks) / (2.0f * num_CUs);
};
static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q)
{
// Please tune the threshold here
const float threshold = 0.8f;
if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold)
return true;
return false;
};
static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q)
{
int i = 2;
// Please tune the threshold here
const float threshold = 1.5f;
while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold)
i++;
// the num_splits shall not be bigger than 64
return ck_tile::min(i, 64);
};

View File

@@ -26,6 +26,7 @@
#include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
#include "hstu_attention_splitkv_helper.hpp"
template <typename InOutDataType,
bool kUseCausal,
@@ -86,6 +87,8 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
constexpr ck_tile::index_t occupancy = -1;
SplitkvWorkspace ws;
{
const bool pad_headdim_qk =
!(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
@@ -126,7 +129,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
}
else
{
@@ -142,7 +145,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
};
});
});
@@ -162,7 +165,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] {
MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] {
constexpr ck_tile::index_t kMaxSplits = [&]() {
if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize)
return TMP_MAX_SPLITS;
@@ -172,7 +175,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
return 4 * TMP_MAX_SPLITS;
}();
const bool pad_num_splits = (param.num_splits < kMaxSplits);
const bool pad_num_splits = (ws.num_splits < kMaxSplits);
using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp<kMaxSplits>;
@@ -196,7 +199,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
});
});
}
@@ -232,31 +235,35 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
});
}
};
template <typename HstuKernel>
static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param,
SplitkvWorkspace& ws,
hipStream_t stream)
{
param.num_splits =
ws.num_splits =
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
// num_splits, hdim]
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
param.num_head * param.num_splits * param.hdim_v *
param.num_head * ws.num_splits * param.hdim_v *
sizeof(OaccDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream));
if constexpr(kUseSoftmax)
{
// assume the workspace for l_acc is in compact shape of [num_batch, max_seqlen,
// num_head, num_splits]
workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
param.num_head * param.num_splits * sizeof(LSEDataType);
param.num_head * ws.num_splits * sizeof(LSEDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.lse_acc_ptr, workspace_bytes, stream));
HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream));
}
const auto kargs = [&] {
@@ -264,9 +271,9 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_acc_ptr,
param.lse_acc_ptr,
param.num_splits,
ws.o_acc_ptr,
ws.lse_acc_ptr,
ws.num_splits,
param.num_batch / param.num_group,
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
@@ -295,7 +302,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
}();
dim3 kGridSize = HstuKernel::GridSize(
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits);
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, ws.num_splits);
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
@@ -306,17 +313,18 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
template <typename HstuKernel>
static void RunWithFwdSplitKVCombineKernel(HstuAttentionGroupFwdParams& param,
SplitkvWorkspace& ws,
hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.o_acc_ptr,
param.lse_acc_ptr,
return HstuKernel::MakeKargs(ws.o_acc_ptr,
ws.lse_acc_ptr,
param.o_ptr,
param.seq_stride_o,
param.nhead_stride_o,
param.seq_q_offsets_ptr,
param.num_head,
param.num_splits,
ws.num_splits,
param.hdim_v);
}();
@@ -328,10 +336,10 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream));
if constexpr(kUseSoftmax)
{
HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream));
HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream));
}
};
};

View File

@@ -26,6 +26,7 @@
#include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp"
#include "hstu_attention_fwd_splitkv_kernel.hpp"
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
#include "hstu_attention_splitkv_helper.hpp"
template <typename InOutDataType,
bool kUseCausal,
@@ -85,6 +86,8 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
{
constexpr ck_tile::index_t occupancy = -1;
SplitkvWorkspace ws;
{
const bool pad_headdim_qk =
!(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
@@ -125,7 +128,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
}
else
{
@@ -141,7 +144,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
};
});
});
@@ -161,7 +164,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
constexpr bool kPadSeqLenQ = false;
MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] {
MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] {
constexpr ck_tile::index_t kMaxSplits = [&]() {
if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize)
return TMP_MAX_SPLITS;
@@ -171,7 +174,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
return 4 * TMP_MAX_SPLITS;
}();
const bool pad_num_splits = (param.num_splits < kMaxSplits);
const bool pad_num_splits = (ws.num_splits < kMaxSplits);
using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp<kMaxSplits>;
@@ -195,7 +198,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
});
});
}
@@ -231,31 +234,35 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuKernel =
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
});
};
};
template <typename HstuKernel>
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param,
SplitkvWorkspace& ws,
hipStream_t stream)
{
param.num_splits =
ws.num_splits =
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
// num_splits, hdim]
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
param.num_head * param.num_splits * param.hdim_v *
param.num_head * ws.num_splits * param.hdim_v *
sizeof(OaccDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.o_acc_ptr, workspace_bytes, stream));
HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream));
if constexpr(kUseSoftmax)
{
// assume the workspace for l_acc is in compact shape of [num_batch, max_seqlen,
// num_head, num_splits]
workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
param.num_head * param.num_splits * sizeof(LSEDataType);
param.num_head * ws.num_splits * sizeof(LSEDataType);
HIP_CHECK_ERROR(hipMallocAsync(&param.lse_acc_ptr, workspace_bytes, stream));
HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream));
}
const auto kargs = [&] {
@@ -263,9 +270,9 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_acc_ptr,
param.lse_acc_ptr,
param.num_splits,
ws.o_acc_ptr,
ws.lse_acc_ptr,
ws.num_splits,
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
@@ -297,7 +304,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
param.num_head,
param.max_seqlen_q,
param.hdim_v,
param.num_splits,
ws.num_splits,
has_minfull_attn_seqlen);
dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
@@ -309,17 +316,18 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
template <typename HstuKernel>
static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
SplitkvWorkspace& ws,
hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.o_acc_ptr,
param.lse_acc_ptr,
return HstuKernel::MakeKargs(ws.o_acc_ptr,
ws.lse_acc_ptr,
param.o_ptr,
param.seq_stride_o,
param.nhead_stride_o,
param.seq_q_offsets_ptr,
param.num_head,
param.num_splits,
ws.num_splits,
param.hdim_v);
}();
@@ -331,10 +339,10 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream));
if constexpr(kUseSoftmax)
{
HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream));
HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream));
}
};
};

View File

@@ -65,14 +65,6 @@ struct HstuAttentionNoGroupFwdParams
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;
// this need not be set by the API users, we only use it for passing num_splits between splitkv
// and combine kernel
int num_splits;
// pointer of device memory allocated before calling fwd_splitkv kernel and released after
// calling combine kernel
void* o_acc_ptr;
void* lse_acc_ptr;
};
struct HstuAttentionGroupFwdParams
@@ -133,12 +125,4 @@ struct HstuAttentionGroupFwdParams
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;
// this need not be set by the API users, it only use it for passing num_splits between splitkv
// and combine kernel
int num_splits;
// pointer of device memory allocated before calling fwd_splitkv kernel and released after
// calling combine kernel
void* o_acc_ptr;
void* lse_acc_ptr;
};

View File

@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "hstu_attention_util.hpp"
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
{
int num_CUs = get_number_of_cu();
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64);
// assume each CU can run two work-groups, common cases for hdim128
return static_cast<float>(nbatch_nhead_mblocks) / (2.0f * num_CUs);
};
static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q)
{
// Please tune the threshold here
const float threshold = 0.8f;
if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold)
return true;
return false;
};
static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q)
{
int i = 2;
// Please tune the threshold here
const float threshold = 1.5f;
while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold)
i++;
// the num_splits shall not be bigger than 64
return ck_tile::min(i, 64);
};
struct SplitkvWorkspace
{
int num_splits;
void* o_acc_ptr;
void* lse_acc_ptr; // only used when softmax is used
};