mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Fix the calling context for type_context in scale_tile_in_scalar()/scale_tile_in_pack
This commit is contained in:
@@ -88,19 +88,18 @@ CK_TILE_DEVICE static void scale_tile_in_scalar(InOutDstrTensor& in_out_dstr_ten
|
||||
|
||||
if constexpr(std::is_same_v<std::remove_cv_t<DataType>, float>)
|
||||
{
|
||||
auto tmp_scale = type_convert<DataType>(scale);
|
||||
|
||||
constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
v_mul_f32_two(in_out_dstr_tensor.thread_buf_[idx],
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1],
|
||||
tmp_scale);
|
||||
scale);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
|
||||
auto tmp_scale = type_convert<DataType>(scale);
|
||||
tile_elementwise_inout([&tmp_scale](auto& x) { x = x * tmp_scale; }, in_out_dstr_tensor);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -128,7 +127,8 @@ CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tenso
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
|
||||
auto tmp_scale = type_convert<DataType>(scale);
|
||||
tile_elementwise_inout([&tmp_scale](auto& x) { x = x * tmp_scale; }, in_out_dstr_tensor);
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user