mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Enable the kernel dispatching path from is_training & use_softmax to kStoreLSE
This commit is contained in:
@@ -647,6 +647,13 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
|
||||
lse_dev.FromDevice(lse_host.data());
|
||||
|
||||
if(dump_output)
|
||||
{
|
||||
dumpBufferToFile("lse_dev.dat", lse_host.data(), lse_host.get_element_space_size());
|
||||
dumpBufferToFile(
|
||||
"lse_host.dat", lse_host_ref.data(), lse_host.get_element_space_size());
|
||||
}
|
||||
|
||||
bool res_lse = ck_tile::check_err(
|
||||
lse_host, lse_host_ref, std::string("hstu_attention lse error"), rtol, atol);
|
||||
|
||||
@@ -1103,6 +1110,13 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
|
||||
lse_dev.FromDevice(lse_host.data());
|
||||
|
||||
if(dump_output)
|
||||
{
|
||||
dumpBufferToFile("lse_dev.dat", lse_host.data(), lse_host.get_element_space_size());
|
||||
dumpBufferToFile(
|
||||
"lse_host.dat", lse_host_ref.data(), lse_host.get_element_space_size());
|
||||
}
|
||||
|
||||
bool res_lse = ck_tile::check_err(
|
||||
lse_host, lse_host_ref, std::string("hstu_attention lse error"), rtol, atol);
|
||||
|
||||
|
||||
@@ -14,16 +14,21 @@ void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param, hipSt
|
||||
{
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
|
||||
bool store_lse = (param.use_softmax && param.is_training);
|
||||
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
false,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
|
||||
if constexpr(kUseSoftmax || !kStoreLSE)
|
||||
{
|
||||
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kStoreLSE,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -14,17 +14,21 @@ void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param, hipSt
|
||||
{
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
const bool store_lse = (param.use_softmax && param.is_training);
|
||||
|
||||
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
false,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
|
||||
if constexpr(kUseSoftmax || !kStoreLSE)
|
||||
{
|
||||
run_group_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kStoreLSE,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,25 +16,32 @@ void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param,
|
||||
{
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
|
||||
bool store_lse = (param.use_softmax && param.is_training);
|
||||
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
if(param.is_jagged)
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
false,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
else
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
false,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
|
||||
if constexpr(kUseSoftmax || !kStoreLSE)
|
||||
{
|
||||
if(param.is_jagged)
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kStoreLSE,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
else
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kStoreLSE,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,25 +16,32 @@ void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param,
|
||||
{
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] {
|
||||
bool store_lse = (param.use_softmax && param.is_training);
|
||||
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
if(param.is_jagged)
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
false,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
else
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
false,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(store_lse, kStoreLSE, [&] {
|
||||
if constexpr(kUseSoftmax || !kStoreLSE)
|
||||
{
|
||||
if(param.is_jagged)
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kStoreLSE,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
else
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<
|
||||
ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kStoreLSE,
|
||||
kHasBias,
|
||||
false, // kHasDropout
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user