diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 341fbf8f8c..6560df75db 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -249,7 +249,7 @@ struct HstuAttentionFwdPipelineQRKSVS // reduction function for softmax const auto f_silu = [](CompDataType& x) { - auto one = ck_tile::type_convert(1.0f); + const auto one = ck_tile::type_convert(1.0f); return x = x / (one + exp(-x)); }; @@ -442,7 +442,13 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0); - const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, s)); + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, s)); + else + return cast_tile(tile_elementwise_in(p_compute_element_func, s)); + }(); move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); k_tile = load_tile(k_dram_window);