mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Fixed fp8 gemm (#882)
* add generic instances; fixed initi with fp8 * fixed comment --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f8_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f8_t operator()(Is...)
|
||||
{
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return ck::type_convert<ck::f8_t>(tmp);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f8_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f8_t operator()(Is...)
|
||||
{
|
||||
float tmp = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp = min_value + tmp * (max_value - min_value);
|
||||
|
||||
return ck::type_convert<ck::f8_t>(fp32_tmp);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user