diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp index 6667fe5dc0..b346f423b9 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp @@ -647,6 +647,13 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged) lse_dev.FromDevice(lse_host.data()); + if(dump_output) + { + dumpBufferToFile("lse_dev.dat", lse_host.data(), lse_host.get_element_space_size()); + dumpBufferToFile( + "lse_host.dat", lse_host_ref.data(), lse_host.get_element_space_size()); + } + bool res_lse = ck_tile::check_err( lse_host, lse_host_ref, std::string("hstu_attention lse error"), rtol, atol); @@ -1103,6 +1110,13 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) lse_dev.FromDevice(lse_host.data()); + if(dump_output) + { + dumpBufferToFile("lse_dev.dat", lse_host.data(), lse_host.get_element_space_size()); + dumpBufferToFile( + "lse_host.dat", lse_host_ref.data(), lse_host.get_element_space_size()); + } + bool res_lse = ck_tile::check_err( lse_host, lse_host_ref, std::string("hstu_attention lse error"), rtol, atol); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_bf16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_bf16.cpp index 7b6222158b..f7170f81e0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_bf16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_bf16.cpp @@ -14,16 +14,21 @@ void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param, hipSt { const bool has_bias = (param.bias_ptr != nullptr); const bool use_causal = param.use_causal; - BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] { + bool store_lse = (param.use_softmax && param.is_training); + + BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { - run_group_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(store_lse, kStoreLSE, [&] { + if constexpr(kUseSoftmax || !kStoreLSE) + { + run_group_forward_causal_softmax_bias_dropout_dispatch(param, stream); + } }); }); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_fp16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_fp16.cpp index 72d25ae8f2..3cbe594fda 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_fp16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_fp16.cpp @@ -14,17 +14,21 @@ void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param, hipSt { const bool has_bias = (param.bias_ptr != nullptr); const bool use_causal = param.use_causal; + const bool store_lse = (param.use_softmax && param.is_training); - BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] { + BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { - run_group_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(store_lse, kStoreLSE, [&] { + if constexpr(kUseSoftmax || !kStoreLSE) + { + run_group_forward_causal_softmax_bias_dropout_dispatch(param, stream); + } }); }); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_bf16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_bf16.cpp index 1f9d9d76f2..fd7c0b5f50 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_bf16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_bf16.cpp @@ -16,25 +16,32 @@ void hstu_attention_no_group_forward_bf16(HstuAttentionNoGroupFwdParams& param, { const bool has_bias = (param.bias_ptr != nullptr); const bool use_causal = param.use_causal; - BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] { + bool store_lse = (param.use_softmax && param.is_training); + + BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { - if(param.is_jagged) - run_jagged_forward_causal_softmax_bias_dropout_dispatch(param, stream); - else - run_batched_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(store_lse, kStoreLSE, [&] { + if constexpr(kUseSoftmax || !kStoreLSE) + { + if(param.is_jagged) + run_jagged_forward_causal_softmax_bias_dropout_dispatch< + ck_tile::bf16_t, + kUseCausal, + kUseSoftmax, + kStoreLSE, + kHasBias, + false, // kHasDropout + MaxK>(param, stream); + else + run_batched_forward_causal_softmax_bias_dropout_dispatch< + ck_tile::bf16_t, + kUseCausal, + kUseSoftmax, + kStoreLSE, + kHasBias, + false, // kHasDropout + MaxK>(param, stream); + } }); }); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_fp16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_fp16.cpp index 17102d4c1d..d54ecbb436 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_fp16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_group_forward_fp16.cpp @@ -16,25 +16,32 @@ void hstu_attention_no_group_forward_fp16(HstuAttentionNoGroupFwdParams& param, { const bool has_bias = (param.bias_ptr != nullptr); const bool use_causal = param.use_causal; - BOOL_SWITCH_2(has_bias, kHasBias, use_causal, kUseCausal, [&] { + bool store_lse = (param.use_softmax && param.is_training); + + BOOL_SWITCH_3(has_bias, kHasBias, use_causal, kUseCausal, param.use_softmax, kUseSoftmax, [&] { HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { - BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] { - if(param.is_jagged) - run_jagged_forward_causal_softmax_bias_dropout_dispatch(param, stream); - else - run_batched_forward_causal_softmax_bias_dropout_dispatch(param, stream); + BOOL_SWITCH(store_lse, kStoreLSE, [&] { + if constexpr(kUseSoftmax || !kStoreLSE) + { + if(param.is_jagged) + run_jagged_forward_causal_softmax_bias_dropout_dispatch< + ck_tile::fp16_t, + kUseCausal, + kUseSoftmax, + kStoreLSE, + kHasBias, + false, // kHasDropout + MaxK>(param, stream); + else + run_batched_forward_causal_softmax_bias_dropout_dispatch< + ck_tile::fp16_t, + kUseCausal, + kUseSoftmax, + kStoreLSE, + kHasBias, + false, // kHasDropout + MaxK>(param, stream); + } }); }); });