updated bfloat16_to_float

[ROCm/composable_kernel commit: 89e1ebd4d5]
Author: Jing Zhang
Date:   2021-11-16 18:01:25 +00:00
parent 456f5306df
commit ea6fa92eea
8 changed files with 93 additions and 165 deletions


@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
 {
     if constexpr(is_same<TIn, ushort>::value)
     {
-        v += bfloat16_to_float(in(n, c, hi, wi)) *
-             bfloat16_to_float(wei(k, c, y, x));
+        v += ck::bf16_to_f32(in(n, c, hi, wi)) *
+             ck::bf16_to_f32(wei(k, c, y, x));
     }
     else
     {
@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
 if constexpr(is_same<TOut, ushort>::value)
 {
-    out(n, k, ho, wo) = float_to_bfloat16(v);
+    out(n, k, ho, wo) = f32_to_bf16(v);
 }
 else
 {
@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
 {
     if constexpr(is_same<TIn, ushort>::value)
     {
-        v += bfloat16_to_float(in(n, hi, wi, c)) *
-             bfloat16_to_float(wei(k, y, x, c));
+        v += ck::bf16_to_f32(in(n, hi, wi, c)) *
+             ck::bf16_to_f32(wei(k, y, x, c));
     }
     else
     {
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
 }
 if constexpr(is_same<TOut, ushort>::value)
 {
-    out(n, ho, wo, k) = float_to_bfloat16(v);
+    out(n, ho, wo, k) = f32_to_bf16(v);
 }
 else
 {
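
Both layout variants above (NCHW and NHWC) follow the usual mixed-precision pattern: bf16 operands are widened to f32, multiplied and accumulated in f32, and the result is narrowed back to bf16 once per output element. A minimal self-contained sketch of that inner loop, using memcpy-based stand-ins (bf16_to_f32_ref and f32_to_bf16_ref are names coined for this note, not the ck:: helpers from the commit):

#include <cstdint>
#include <cstring>

static float bf16_to_f32_ref(std::uint16_t x)
{
    std::uint32_t bits = std::uint32_t(x) << 16; // bf16 is the high half of an f32
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static std::uint16_t f32_to_bf16_ref(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return std::uint16_t(bits >> 16); // truncate the low 16 mantissa bits
}

// One output element of the convolution above: accumulate in f32, narrow once.
static std::uint16_t conv_point_bf16(const std::uint16_t* in, const std::uint16_t* wei, int len)
{
    float v = 0.f;
    for(int i = 0; i < len; ++i)
        v += bf16_to_f32_ref(in[i]) * bf16_to_f32_ref(wei[i]);
    return f32_to_bf16_ref(v);
}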


@@ -16,10 +16,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
+        v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
     }
-    c(m, n) = float_to_bfloat16(v);
+    c(m, n) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -34,10 +34,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
+        v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
     }
-    c(m, n) = float_to_bfloat16(v);
+    c(m, n) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -52,10 +52,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
+        v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
     }
-    c(m, n) = float_to_bfloat16(v);
+    c(m, n) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -70,10 +70,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
+        v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
     }
-    c(m, n) = float_to_bfloat16(v);
+    c(m, n) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -88,10 +88,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
+        v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
     }
-    c(n, m) = float_to_bfloat16(v);
+    c(n, m) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -106,10 +106,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
+        v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
     }
-    c(n, m) = float_to_bfloat16(v);
+    c(n, m) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -124,10 +124,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
+        v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
     }
-    c(n, m) = float_to_bfloat16(v);
+    c(n, m) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
@@ -142,10 +142,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
     for(int k = 0; k < K; ++k)
     {
-        v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
+        v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
     }
-    c(n, m) = float_to_bfloat16(v);
+    c(n, m) = ck::f32_to_bf16(v);
 };
 make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
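
The eight specializations above share one kernel; they differ only in whether A is stored MxK or KxM, B is KxN or NxK, and C is written MxN or NxM. A single strided reference captures every variant; the stride parameters (a_rs, a_cs, ...) are coined for this sketch, and it is written in plain f32 since the bf16 widening/narrowing shown in the diff is orthogonal to the layout question:

void gemm_ref(const float* a, int a_rs, int a_cs, // A(m,k) = a[m*a_rs + k*a_cs]
              const float* b, int b_rs, int b_cs, // B(k,n) = b[k*b_rs + n*b_cs]
              float* c, int c_rs, int c_cs,       // C(m,n) = c[m*c_rs + n*c_cs]
              int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float v = 0.f;
            for(int k = 0; k < K; ++k)
                v += a[m * a_rs + k * a_cs] * b[k * b_rs + n * b_cs];
            c[m * c_rs + n * c_cs] = v;
        }
}

In these terms, f_mk_kn_mn corresponds to (a_rs=K, a_cs=1, b_rs=N, b_cs=1, c_rs=N, c_cs=1); f_mk_nk_mn flips B's strides to (b_rs=1, b_cs=K); the _nm variants swap C's strides; and so on for the remaining combinations.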


@@ -321,18 +321,14 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
}
float bf16_to_f32(ushort src_val)
__host__ __device__ float bf16_to_f32(ushort src_val)
{
typedef union
union
{
ushort x, y;
float f32;
} bf16_f32_t;
bf16_f32_t v;
v.x = 0;
v.y = src_val;
return v.f32;
uint32_t int32;
float fp32;
} u = {uint32_t(src_val) << 16};
return u.fp32;
}
template <>
@@ -354,8 +350,7 @@ void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result
     }
     std::cout << "error: " << error << std::endl;
-    std::cout << "max_diff: " << max_diff << ", ref: " << ref_value << ", res: " << result_value
-              << std::endl;
+    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
 }
 #endif
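
This hunk is the substantive fix of the commit. In the old union, `ushort x, y;` declares two union members, so x and y alias the same two bytes at offset 0: `v.x = 0; v.y = src_val;` wrote the bf16 bits into the low half of the float's representation and left the high half unwritten. bfloat16 is defined as the high 16 bits of an f32, which is exactly what the new `uint32_t(src_val) << 16` initializer produces. A portable equivalent that avoids union type-punning altogether (bf16_to_f32_portable is a name coined for this note):

#include <cstdint>
#include <cstring>

float bf16_to_f32_portable(std::uint16_t src_val)
{
    std::uint32_t bits = std::uint32_t(src_val) << 16; // bf16 bits -> high half of f32
    float f;
    std::memcpy(&f, &bits, sizeof(f)); // well-defined, unlike union punning in ISO C++
    return f;
}

Reading a union member other than the one last written is technically undefined behaviour in ISO C++ (Clang/HIP support it as an extension, which is why the committed version works); memcpy, or std::bit_cast in C++20, expresses the same bit movement portably.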


@@ -3,6 +3,7 @@
 #include <cmath>
 #include "config.hpp"
+#include "data_type.hpp"
 
 template <typename T>
 struct GeneratorTensor_1
@@ -24,7 +25,7 @@ struct GeneratorTensor_1<ushort>
     template <typename... Is>
     ushort operator()(Is...)
     {
-        return float_to_bfloat16(value);
+        return ck::f32_to_bf16(value);
     }
 };
@@ -74,7 +75,7 @@ struct GeneratorTensor_2<ushort>
     ushort operator()(Is...)
     {
         float tmp = (std::rand() % (max_value - min_value)) + min_value;
-        return float_to_bfloat16(tmp);
+        return ck::f32_to_bf16(tmp);
     }
 };
@@ -119,7 +120,7 @@ struct GeneratorTensor_3<ushort>
         float fp32_tmp = min_value + tmp * (max_value - min_value);
-        return float_to_bfloat16(fp32_tmp);
+        return ck::f32_to_bf16(fp32_tmp);
     }
 };
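
All three generator specializations now narrow their f32 values through ck::f32_to_bf16, whose body is not part of this diff, so whether it truncates or rounds is not visible here. Both common variants are sketched below, with names coined for this note:

#include <cstdint>
#include <cstring>

std::uint16_t f32_to_bf16_truncate(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return std::uint16_t(bits >> 16); // drop the low 16 mantissa bits
}

std::uint16_t f32_to_bf16_rne(float f) // round-to-nearest-even (NaNs not special-cased)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits += 0x7FFF + ((bits >> 16) & 1); // carry into the kept half when rounding up
    return std::uint16_t(bits >> 16);
}

Truncation is cheaper and common in test harnesses like these generators; round-to-nearest-even halves the worst-case error and matches what bf16 hardware conversion instructions typically do.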