mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Move num_splits/o_acc_ptr/lse_acc_ptr out of the HstuAttention<xxx>FwdParams structs into a separate SplitkvWorkspace struct
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
#include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_splitkv_kernel.hpp"
|
||||
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "hstu_attention_splitkv_helper.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
@@ -85,6 +86,8 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
SplitkvWorkspace ws;
|
||||
|
||||
{
|
||||
const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionFwdTileSetting::kN0 == 0);
|
||||
const bool pad_headdim_qk =
|
||||
@@ -134,7 +137,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -150,7 +153,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
@@ -170,7 +173,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] {
|
||||
MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] {
|
||||
constexpr ck_tile::index_t kMaxSplits = [&]() {
|
||||
if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize)
|
||||
return TMP_MAX_SPLITS;
|
||||
@@ -180,7 +183,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
return 4 * TMP_MAX_SPLITS;
|
||||
}();
|
||||
|
||||
const bool pad_num_splits = (param.num_splits < kMaxSplits);
|
||||
const bool pad_num_splits = (ws.num_splits < kMaxSplits);
|
||||
|
||||
using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp<kMaxSplits>;
|
||||
|
||||
@@ -204,7 +207,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -240,31 +243,34 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param,
|
||||
SplitkvWorkspace& ws,
|
||||
hipStream_t stream)
|
||||
{
|
||||
param.num_splits =
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q);
|
||||
ws.num_splits = get_suggested_num_splits(param.num_batch, param.num_head, param.seqlen_q);
|
||||
|
||||
// assume the workspace for o_acc is in compact shape of [num_batch, seqlen_q, num_head,
|
||||
// num_splits, hdim]
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.seqlen_q *
|
||||
param.num_head * param.num_splits * param.hdim_v *
|
||||
param.num_head * ws.num_splits * param.hdim_v *
|
||||
sizeof(OaccDataType);
|
||||
|
||||
HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream));
|
||||
HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream));
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
// assume the workspace for l_acc is in compact shape of [num_batch, seqlen_q,
|
||||
// num_head, num_splits]
|
||||
workspace_bytes = static_cast<size_t>(param.num_batch) * param.seqlen_q *
|
||||
param.num_head * param.num_splits * sizeof(LSEDataType);
|
||||
param.num_head * ws.num_splits * sizeof(LSEDataType);
|
||||
|
||||
HIP_CHECK_ERROR(hipMallocAsync(¶m.lse_acc_ptr, workspace_bytes, stream));
|
||||
HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream));
|
||||
}
|
||||
|
||||
const auto kargs = [&] {
|
||||
@@ -272,9 +278,9 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_acc_ptr,
|
||||
param.lse_acc_ptr,
|
||||
param.num_splits,
|
||||
ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
ws.num_splits,
|
||||
param.seqlen_q,
|
||||
param.is_cross_attention ? param.seqlen_kv
|
||||
: param.seqlen_q,
|
||||
@@ -309,7 +315,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.num_head,
|
||||
param.seqlen_q,
|
||||
param.hdim_v,
|
||||
param.num_splits,
|
||||
ws.num_splits,
|
||||
has_minfull_attn_seqlen);
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
@@ -321,18 +327,19 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
|
||||
SplitkvWorkspace& ws,
|
||||
hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.o_acc_ptr,
|
||||
param.lse_acc_ptr,
|
||||
return HstuKernel::MakeKargs(ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
param.o_ptr,
|
||||
param.batch_stride_o,
|
||||
param.seq_stride_o,
|
||||
param.nhead_stride_o,
|
||||
param.seqlen_q,
|
||||
param.num_head,
|
||||
param.num_splits,
|
||||
ws.num_splits,
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
@@ -344,10 +351,10 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
ck_tile::stream_config{stream, false},
|
||||
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
|
||||
|
||||
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
|
||||
HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream));
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream));
|
||||
HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -484,37 +484,3 @@ static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_
|
||||
// mtile-64 can be added through tuning/verification
|
||||
return 64;
|
||||
};
|
||||
|
||||
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
int num_CUs = get_number_of_cu();
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
|
||||
int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64);
|
||||
|
||||
// assume each CU can run two work-groups, common cases for hdim128
|
||||
return static_cast<float>(nbatch_nhead_mblocks) / (2.0f * num_CUs);
|
||||
};
|
||||
|
||||
static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
// Please tune the threshold here
|
||||
const float threshold = 0.8f;
|
||||
|
||||
if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold)
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
int i = 2;
|
||||
|
||||
// Please tune the threshold here
|
||||
const float threshold = 1.5f;
|
||||
while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold)
|
||||
i++;
|
||||
|
||||
// the num_splits shall not be bigger than 64
|
||||
return ck_tile::min(i, 64);
|
||||
};
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_splitkv_kernel.hpp"
|
||||
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "hstu_attention_splitkv_helper.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
@@ -86,6 +87,8 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
SplitkvWorkspace ws;
|
||||
|
||||
{
|
||||
const bool pad_headdim_qk =
|
||||
!(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
|
||||
@@ -126,7 +129,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -142,7 +145,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
@@ -162,7 +165,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] {
|
||||
MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] {
|
||||
constexpr ck_tile::index_t kMaxSplits = [&]() {
|
||||
if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize)
|
||||
return TMP_MAX_SPLITS;
|
||||
@@ -172,7 +175,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
return 4 * TMP_MAX_SPLITS;
|
||||
}();
|
||||
|
||||
const bool pad_num_splits = (param.num_splits < kMaxSplits);
|
||||
const bool pad_num_splits = (ws.num_splits < kMaxSplits);
|
||||
|
||||
using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp<kMaxSplits>;
|
||||
|
||||
@@ -196,7 +199,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -232,31 +235,35 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionGroupFwdParams& param,
|
||||
SplitkvWorkspace& ws,
|
||||
hipStream_t stream)
|
||||
{
|
||||
param.num_splits =
|
||||
ws.num_splits =
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
|
||||
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
|
||||
// num_splits, hdim]
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
|
||||
param.num_head * param.num_splits * param.hdim_v *
|
||||
param.num_head * ws.num_splits * param.hdim_v *
|
||||
sizeof(OaccDataType);
|
||||
|
||||
HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream));
|
||||
HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream));
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
// assume the workspace for l_acc is in compact shape of [num_batch, max_seqlen,
|
||||
// num_head, num_splits]
|
||||
workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
|
||||
param.num_head * param.num_splits * sizeof(LSEDataType);
|
||||
param.num_head * ws.num_splits * sizeof(LSEDataType);
|
||||
|
||||
HIP_CHECK_ERROR(hipMallocAsync(¶m.lse_acc_ptr, workspace_bytes, stream));
|
||||
HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream));
|
||||
}
|
||||
|
||||
const auto kargs = [&] {
|
||||
@@ -264,9 +271,9 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_acc_ptr,
|
||||
param.lse_acc_ptr,
|
||||
param.num_splits,
|
||||
ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
ws.num_splits,
|
||||
param.num_batch / param.num_group,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
@@ -295,7 +302,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
}();
|
||||
|
||||
dim3 kGridSize = HstuKernel::GridSize(
|
||||
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, param.num_splits);
|
||||
param.num_batch, param.num_head, param.max_seqlen_q, param.hdim_v, ws.num_splits);
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
@@ -306,17 +313,18 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithFwdSplitKVCombineKernel(HstuAttentionGroupFwdParams& param,
|
||||
SplitkvWorkspace& ws,
|
||||
hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.o_acc_ptr,
|
||||
param.lse_acc_ptr,
|
||||
return HstuKernel::MakeKargs(ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
param.o_ptr,
|
||||
param.seq_stride_o,
|
||||
param.nhead_stride_o,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.num_head,
|
||||
param.num_splits,
|
||||
ws.num_splits,
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
@@ -328,10 +336,10 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
ck_tile::stream_config{stream, false},
|
||||
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
|
||||
|
||||
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
|
||||
HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream));
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream));
|
||||
HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_splitkv_kernel.hpp"
|
||||
#include "hstu_attention_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "hstu_attention_splitkv_helper.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
@@ -85,6 +86,8 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
SplitkvWorkspace ws;
|
||||
|
||||
{
|
||||
const bool pad_headdim_qk =
|
||||
!(param.hdim_qk % HstuAttentionFwdTileSetting::kQKHeaddim == 0);
|
||||
@@ -125,7 +128,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -141,7 +144,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVKernel<HstuKernel>(param, ws, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
@@ -161,7 +164,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
MAX_SPLITS_SWITCH(param.num_splits, TMP_MAX_SPLITS, [&] {
|
||||
MAX_SPLITS_SWITCH(ws.num_splits, TMP_MAX_SPLITS, [&] {
|
||||
constexpr ck_tile::index_t kMaxSplits = [&]() {
|
||||
if constexpr(kM * TMP_MAX_SPLITS >= kBlockSize)
|
||||
return TMP_MAX_SPLITS;
|
||||
@@ -171,7 +174,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
return 4 * TMP_MAX_SPLITS;
|
||||
}();
|
||||
|
||||
const bool pad_num_splits = (param.num_splits < kMaxSplits);
|
||||
const bool pad_num_splits = (ws.num_splits < kMaxSplits);
|
||||
|
||||
using HstuCombinePipelineProblem = HstuCombinePipelineProblemTemp<kMaxSplits>;
|
||||
|
||||
@@ -195,7 +198,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -231,31 +234,35 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuKernel =
|
||||
ck_tile::HstuAttentionFwdSplitKVCombineKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, stream);
|
||||
RunWithFwdSplitKVCombineKernel<HstuKernel>(param, ws, stream);
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
static void RunWithFwdSplitKVKernel(HstuAttentionNoGroupFwdParams& param,
|
||||
SplitkvWorkspace& ws,
|
||||
hipStream_t stream)
|
||||
{
|
||||
param.num_splits =
|
||||
ws.num_splits =
|
||||
get_suggested_num_splits(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
|
||||
// assume the workspace for o_acc is in compact shape of [num_batch, max_seqlen, num_head,
|
||||
// num_splits, hdim]
|
||||
size_t workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
|
||||
param.num_head * param.num_splits * param.hdim_v *
|
||||
param.num_head * ws.num_splits * param.hdim_v *
|
||||
sizeof(OaccDataType);
|
||||
|
||||
HIP_CHECK_ERROR(hipMallocAsync(¶m.o_acc_ptr, workspace_bytes, stream));
|
||||
HIP_CHECK_ERROR(hipMallocAsync(&ws.o_acc_ptr, workspace_bytes, stream));
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
// assume the workspace for l_acc is in compact shape of [num_batch, max_seqlen,
|
||||
// num_head, num_splits]
|
||||
workspace_bytes = static_cast<size_t>(param.num_batch) * param.max_seqlen_q *
|
||||
param.num_head * param.num_splits * sizeof(LSEDataType);
|
||||
param.num_head * ws.num_splits * sizeof(LSEDataType);
|
||||
|
||||
HIP_CHECK_ERROR(hipMallocAsync(¶m.lse_acc_ptr, workspace_bytes, stream));
|
||||
HIP_CHECK_ERROR(hipMallocAsync(&ws.lse_acc_ptr, workspace_bytes, stream));
|
||||
}
|
||||
|
||||
const auto kargs = [&] {
|
||||
@@ -263,9 +270,9 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_acc_ptr,
|
||||
param.lse_acc_ptr,
|
||||
param.num_splits,
|
||||
ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
ws.num_splits,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
@@ -297,7 +304,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
param.num_head,
|
||||
param.max_seqlen_q,
|
||||
param.hdim_v,
|
||||
param.num_splits,
|
||||
ws.num_splits,
|
||||
has_minfull_attn_seqlen);
|
||||
dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
@@ -309,17 +316,18 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithFwdSplitKVCombineKernel(HstuAttentionNoGroupFwdParams& param,
|
||||
SplitkvWorkspace& ws,
|
||||
hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.o_acc_ptr,
|
||||
param.lse_acc_ptr,
|
||||
return HstuKernel::MakeKargs(ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
param.o_ptr,
|
||||
param.seq_stride_o,
|
||||
param.nhead_stride_o,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.num_head,
|
||||
param.num_splits,
|
||||
ws.num_splits,
|
||||
param.hdim_v);
|
||||
}();
|
||||
|
||||
@@ -331,10 +339,10 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
ck_tile::stream_config{stream, false},
|
||||
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
|
||||
|
||||
HIP_CHECK_ERROR(hipFreeAsync(param.o_acc_ptr, stream));
|
||||
HIP_CHECK_ERROR(hipFreeAsync(ws.o_acc_ptr, stream));
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFreeAsync(param.lse_acc_ptr, stream));
|
||||
HIP_CHECK_ERROR(hipFreeAsync(ws.lse_acc_ptr, stream));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -65,14 +65,6 @@ struct HstuAttentionNoGroupFwdParams
|
||||
float p_drop;
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
|
||||
// this need not be set by the API users, we only use it for passing num_splits between splitkv
|
||||
// and combine kernel
|
||||
int num_splits;
|
||||
// pointer of device memory allocated before calling fwd_splitkv kernel and released after
|
||||
// calling combine kernel
|
||||
void* o_acc_ptr;
|
||||
void* lse_acc_ptr;
|
||||
};
|
||||
|
||||
struct HstuAttentionGroupFwdParams
|
||||
@@ -133,12 +125,4 @@ struct HstuAttentionGroupFwdParams
|
||||
float p_drop;
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
|
||||
// this need not be set by the API users, it only use it for passing num_splits between splitkv
|
||||
// and combine kernel
|
||||
int num_splits;
|
||||
// pointer of device memory allocated before calling fwd_splitkv kernel and released after
|
||||
// calling combine kernel
|
||||
void* o_acc_ptr;
|
||||
void* lse_acc_ptr;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "hstu_attention_util.hpp"
|
||||
|
||||
static float get_estimated_cu_coverage_ratio(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
int num_CUs = get_number_of_cu();
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
|
||||
int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 64);
|
||||
|
||||
// assume each CU can run two work-groups, common cases for hdim128
|
||||
return static_cast<float>(nbatch_nhead_mblocks) / (2.0f * num_CUs);
|
||||
};
|
||||
|
||||
static bool shall_use_splitkv(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
// Please tune the threshold here
|
||||
const float threshold = 0.8f;
|
||||
|
||||
if(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) < threshold)
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
static int get_suggested_num_splits(int num_batches, int num_heads, int max_seqlen_q)
|
||||
{
|
||||
int i = 2;
|
||||
|
||||
// Please tune the threshold here
|
||||
const float threshold = 1.5f;
|
||||
while(get_estimated_cu_coverage_ratio(num_batches, num_heads, max_seqlen_q) * i < threshold)
|
||||
i++;
|
||||
|
||||
// the num_splits shall not be bigger than 64
|
||||
return ck_tile::min(i, 64);
|
||||
};
|
||||
|
||||
// Scratch state shared between the split-kv forward kernel launch and the following
// combine kernel launch.
//
// All members carry default initializers so that a default-constructed workspace never
// exposes indeterminate values; in particular lse_acc_ptr is forwarded to the kernel
// arguments even on code paths that never allocate it.
struct SplitkvWorkspace
{
    // number of splits chosen for the current launch; 0 until it has been decided
    int num_splits = 0;

    // accumulation buffer for the per-split outputs; nullptr while not allocated
    void* o_acc_ptr = nullptr;

    // accumulation buffer for the per-split LSE values; nullptr while not allocated
    void* lse_acc_ptr = nullptr; // only used when softmax is used
};
|
||||
Reference in New Issue
Block a user