From cad1356170a0723285095afe47bc056a8a62b019 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Tue, 15 Apr 2025 14:29:30 +0000
Subject: [PATCH] Use packed cast_tile for fp16

---
 .../18_hstu_attention/hstu_attention_fwd_pipeline.hpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
index 341fbf8f8c..6560df75db 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
@@ -249,7 +249,7 @@ struct HstuAttentionFwdPipelineQRKSVS
 
         // reduction function for softmax
         const auto f_silu = [](CompDataType& x) {
-            auto one = ck_tile::type_convert(1.0f);
+            const auto one = ck_tile::type_convert(1.0f);
             return x = x / (one + exp(-x));
         };
 
@@ -442,7 +442,13 @@ struct HstuAttentionFwdPipelineQRKSVS
 
             __builtin_amdgcn_sched_barrier(0);
 
-            const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, s));
+            const auto p = [&]() {
+                if constexpr(std::is_same_v)
+                    return impl::cast_tile_pk_fp16_fp32(
+                        tile_elementwise_in(p_compute_element_func, s));
+                else
+                    return cast_tile(tile_elementwise_in(p_compute_element_func, s));
+            }();
 
             move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
             k_tile = load_tile(k_dram_window);