Fixed fp8 gemm (#882)

* add generic instances; fixed initi with fp8

* fixed comment

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
zjing14
2023-09-06 09:59:20 -05:00
committed by GitHub
parent aae4df5596
commit a61b8b785e
5 changed files with 106 additions and 15 deletions

View File

@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_2<ck::f8_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f8_t>(tmp);
}
};
#endif
template <typename T>
struct GeneratorTensor_3
{
@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_3<ck::f8_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f8_t>(fp32_tmp);
}
};
#endif
template <typename T>
struct GeneratorTensor_4
{