Fix the calling context for type_context in scale_tile_in_scalar()/scale_tile_in_pack

This commit is contained in:
Qianfeng Zhang
2026-05-11 08:53:07 +00:00
parent 0a32eddc0a
commit 8a7529177d

View File

@@ -88,19 +88,18 @@ CK_TILE_DEVICE static void scale_tile_in_scalar(InOutDstrTensor& in_out_dstr_ten
if constexpr(std::is_same_v<std::remove_cv_t<DataType>, float>)
{
auto tmp_scale = type_convert<DataType>(scale);
constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
v_mul_f32_two(in_out_dstr_tensor.thread_buf_[idx],
in_out_dstr_tensor.thread_buf_[idx + 1],
tmp_scale);
scale);
});
}
else
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
auto tmp_scale = type_convert<DataType>(scale);
tile_elementwise_inout([&tmp_scale](auto& x) { x = x * tmp_scale; }, in_out_dstr_tensor);
};
};
@@ -128,7 +127,8 @@ CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tenso
}
else
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
auto tmp_scale = type_convert<DataType>(scale);
tile_elementwise_inout([&tmp_scale](auto& x) { x = x * tmp_scale; }, in_out_dstr_tensor);
};
};