From a651ea4f7a1404b9563169474ec927d15401f310 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 18 Nov 2021 08:10:56 -0600 Subject: [PATCH] Fixed bfp16 host_conv_fwd (#52) * fixed bfloat16 issues * refactor type_convert * fixed host_convolution_forward for ushort Co-authored-by: Chao Liu --- host/driver_offline/src/conv_fwd_driver_offline.cpp | 6 +++--- host/driver_offline/src/gemm_driver_offline.cpp | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index f1ae9dc515..30a72e3bbb 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor& in, if constexpr(is_same::value) { - out(n, k, ho, wo) = type_convert(v); + out(n, k, ho, wo) = ck::type_convert(static_cast(v)); } else { @@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor& in, } if constexpr(is_same::value) { - out(n, ho, wo, k) = ck::type_convert(v); + out(n, ho, wo, k) = ck::type_convert(static_cast(v)); } else { @@ -257,7 +257,7 @@ int main(int argc, char* argv[]) using in_data_t = float; using acc_data_t = float; using out_data_t = float; -#elif 0 +#elif 1 using in_data_t = half_t; using acc_data_t = float; using out_data_t = half_t; diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index 23158b7b66..bd8cb00390 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -239,14 +239,10 @@ int main(int argc, char* argv[]) using ab_data_t = float; using acc_data_t = float; using c_data_t = float; -#elif 0 +#elif 1 using ab_data_t = half_t; using acc_data_t = float; using c_data_t = half_t; -#elif 1 - using ab_data_t = ushort; - using acc_data_t = float; - using c_data_t = ushort; #elif 1 using ab_data_t = int8_t; using acc_data_t = int32_t;