updated bfloat16_to_float

[ROCm/composable_kernel commit: 89e1ebd4d5]
This commit is contained in:
Jing Zhang
2021-11-16 18:01:25 +00:00
parent 456f5306df
commit ea6fa92eea
8 changed files with 93 additions and 165 deletions

View File

@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
if constexpr(is_same<TIn, ushort>::value)
{
v += bfloat16_to_float(in(n, c, hi, wi)) *
bfloat16_to_float(wei(k, c, y, x));
v += ck::bf16_to_f32(in(n, c, hi, wi)) *
ck::bf16_to_f32(wei(k, c, y, x));
}
else
{
@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if constexpr(is_same<TOut, ushort>::value)
{
out(n, k, ho, wo) = float_to_bfloat16(v);
out(n, k, ho, wo) = f32_to_bf16(v);
}
else
{
@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
if constexpr(is_same<TIn, ushort>::value)
{
v += bfloat16_to_float(in(n, hi, wi, c)) *
bfloat16_to_float(wei(k, y, x, c));
v += ck::bf16_to_f32(in(n, hi, wi, c)) *
ck::bf16_to_f32(wei(k, y, x, c));
}
else
{
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
if constexpr(is_same<TOut, ushort>::value)
{
out(n, ho, wo, k) = float_to_bfloat16(v);
out(n, ho, wo, k) = f32_to_bf16(v);
}
else
{