Fix the fp8 gemm for large tensors on MI300. (#1011)

* Fix the fp8 conversion

* Try clipping value before conversion

* Fix return

* Simplify with a const

* reduce the gemm input tensor values to reduce round-off error

* replace if-else with lambda

* fix syntax

---------

Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com>

[ROCm/composable_kernel commit: f46a6ffad8]
This commit is contained in:
Illia Silin
2023-10-27 21:10:47 -07:00
committed by GitHub
parent 6cf76b5095
commit 60b282fb62
4 changed files with 10 additions and 8 deletions

View File

@@ -8,8 +8,8 @@ int run_groupnorm_example()
{
bool time_kernel = false;
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t M = 1024;
ck::index_t N = 1024;
Tensor<XDataType> x({M, N});
Tensor<GammaDataType> gamma({N});
@@ -44,9 +44,9 @@ int run_groupnorm_example()
{0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
{1},
1e-4,
x_dev.GetDeviceBuffer(),

View File

@@ -65,9 +65,9 @@ int run_groupnorm_example(int argc, char* argv[])
{0, 0, 0, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
save_mean.mDesc.GetStrides().end()},
{1, 2, 4}, // reduction dimension: [H, W, C]
1e-6,
x_dev.GetDeviceBuffer(),