Enable kernel dispatch on kStoreLSE, derived at runtime from (use_softmax && is_training)

This commit is contained in:
Qianfeng Zhang
2026-06-05 06:46:07 +00:00
parent 8b62d651a4
commit 798fd3cd8b
5 changed files with 91 additions and 54 deletions

View File

@@ -647,6 +647,13 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
lse_dev.FromDevice(lse_host.data());
if(dump_output)
{
dumpBufferToFile("lse_dev.dat", lse_host.data(), lse_host.get_element_space_size());
dumpBufferToFile(
"lse_host.dat", lse_host_ref.data(), lse_host.get_element_space_size());
}
bool res_lse = ck_tile::check_err(
lse_host, lse_host_ref, std::string("hstu_attention lse error"), rtol, atol);
@@ -1103,6 +1110,13 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
lse_dev.FromDevice(lse_host.data());
if(dump_output)
{
dumpBufferToFile("lse_dev.dat", lse_host.data(), lse_host.get_element_space_size());
dumpBufferToFile(
"lse_host.dat", lse_host_ref.data(), lse_host.get_element_space_size());
}
bool res_lse = ck_tile::check_err(
lse_host, lse_host_ref, std::string("hstu_attention lse error"), rtol, atol);

View File

@@ -14,16 +14,21 @@ void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param, hipSt
{
const bool has_bias = (param.bias_ptr != nullptr);
const bool use_causal = param.use_causal;
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
bool store_lse = (param.use_softmax && param.is_training);
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
false,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
if constexpr(kUseSoftmax || !kStoreLSE)
{
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kStoreLSE,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
}
});
});
});

View File

@@ -14,17 +14,21 @@ void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param, hipSt
{
const bool has_bias = (param.bias_ptr != nullptr);
const bool use_causal = param.use_causal;
const bool store_lse = (param.use_softmax && param.is_training);
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
false,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
if constexpr(kUseSoftmax || !kStoreLSE)
{
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kStoreLSE,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
}
});
});
});

View File

@@ -16,25 +16,32 @@ void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param,
{
const bool has_bias = (param.bias_ptr != nullptr);
const bool use_causal = param.use_causal;
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
bool store_lse = (param.use_softmax && param.is_training);
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
if(param.is_jagged)
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
false,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
else
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
false,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
if constexpr(kUseSoftmax || !kStoreLSE)
{
if(param.is_jagged)
run_jagged_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kStoreLSE,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
else
run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::bf16_t,
kUseCausal,
kUseSoftmax,
kStoreLSE,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
}
});
});
});

View File

@@ -16,25 +16,32 @@ void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param,
{
const bool has_bias = (param.bias_ptr != nullptr);
const bool use_causal = param.use_causal;
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
bool store_lse = (param.use_softmax && param.is_training);
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
if(param.is_jagged)
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
false,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
else
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
false,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
if constexpr(kUseSoftmax || !kStoreLSE)
{
if(param.is_jagged)
run_jagged_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kStoreLSE,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
else
run_batched_forward_causal_softmax_bias_dropout_dispatch<
ck_tile::fp16_t,
kUseCausal,
kUseSoftmax,
kStoreLSE,
kHasBias,
false, // kHasDropout
MaxK>(param, stream);
}
});
});
});