diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp index e08bed90ad..0a7a22e5ab 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp @@ -88,19 +88,18 @@ CK_TILE_DEVICE static void scale_tile_in_scalar(InOutDstrTensor& in_out_dstr_ten if constexpr(std::is_same_v, float>) { - auto tmp_scale = type_convert(scale); - constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size(); static_for<0, thread_buffer_size, 2>{}([&](auto idx) { v_mul_f32_two(in_out_dstr_tensor.thread_buf_[idx], in_out_dstr_tensor.thread_buf_[idx + 1], - tmp_scale); + scale); }); } else { - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor); + auto tmp_scale = type_convert(scale); + tile_elementwise_inout([&tmp_scale](auto& x) { x = x * tmp_scale; }, in_out_dstr_tensor); }; }; @@ -128,7 +127,8 @@ CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tenso } else { - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor); + auto tmp_scale = type_convert(scale); + tile_elementwise_inout([&tmp_scale](auto& x) { x = x * tmp_scale; }, in_out_dstr_tensor); }; };