mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix the fp8 gemm for large tensors on MI300. (#1011)
* Fix the fp8 conversion
* Try clipping value before conversion
* Fix return
* Simplify with a const
* reduce the gemm input tensor values to reduce round-off error
* replace if-else with lambda
* fix syntax
---------
Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com>
[ROCm/composable_kernel commit: f46a6ffad8]
This commit is contained in:
@@ -8,8 +8,8 @@ int run_groupnorm_example()
|
||||
{
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
|
||||
Tensor<XDataType> x({M, N});
|
||||
Tensor<GammaDataType> gamma({N});
|
||||
@@ -44,9 +44,9 @@ int run_groupnorm_example()
|
||||
{0, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -65,9 +65,9 @@ int run_groupnorm_example(int argc, char* argv[])
|
||||
{0, 0, 0, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 4}, // reduction dimension: [H, W, C]
|
||||
1e-6,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -100,6 +100,8 @@ template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
|
||||
@@ -75,8 +75,8 @@ int profile_gemm_impl(int do_verification,
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 0.1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.05, 0.05});
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
Reference in New Issue
Block a user