Fixed bfp16 host_conv_fwd (#52)

* fixed bfloat16 issues

* refactor type_convert

* fixed host_convolution_forward for ushort

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
zjing14
2021-11-18 08:10:56 -06:00
committed by GitHub
parent 0a66c54e95
commit a651ea4f7a
2 changed files with 4 additions and 8 deletions

View File

@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if constexpr(is_same<TOut, ushort>::value)
{
out(n, k, ho, wo) = type_convert<ushort>(v);
out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
}
else
{
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
if constexpr(is_same<TOut, ushort>::value)
{
out(n, ho, wo, k) = ck::type_convert<ushort>(v);
out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
}
else
{
@@ -257,7 +257,7 @@ int main(int argc, char* argv[])
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 0
#elif 1
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;

View File

@@ -239,14 +239,10 @@ int main(int argc, char* argv[])
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
#elif 0
#elif 1
using ab_data_t = half_t;
using acc_data_t = float;
using c_data_t = half_t;
#elif 1
using ab_data_t = ushort;
using acc_data_t = float;
using c_data_t = ushort;
#elif 1
using ab_data_t = int8_t;
using acc_data_t = int32_t;