diff --git a/CMakeLists.txt b/CMakeLists.txt
index b28a6d9127..2c86987561 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -185,13 +185,22 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     add_definitions(-DCK_USE_XDL)
 endif()
 if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
-    message("Enabling FP8 gemms in ckProfiler")
+    message("Enabling FP8 gemms on native architectures")
     add_definitions(-DCK_USE_GFX94)
 endif()
 if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
     message("Enabling WMMA instances")
     add_definitions(-DCK_USE_WMMA)
 endif()
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
+    add_definitions(-DCK_USE_OCP_FP8)
+    set(CK_USE_OCP_FP8 "ON")
+endif()
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
+    add_definitions(-DCK_USE_FNUZ_FP8)
+    set(CK_USE_FNUZ_FP8 "ON")
+endif()
+
 option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
 if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
     add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
index c393972b42..ce5834d1e2 100644
--- a/client_example/CMakeLists.txt
+++ b/client_example/CMakeLists.txt
@@ -56,6 +56,14 @@ if (GPU_TARGETS)
         add_definitions(-DCK_USE_WMMA)
         set(CK_USE_WMMA "ON")
     endif()
+    if (GPU_TARGETS MATCHES "gfx12")
+        add_definitions(-DCK_USE_OCP_FP8)
+        set(CK_USE_OCP_FP8 "ON")
+    endif()
+    if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
+        add_definitions(-DCK_USE_FNUZ_FP8)
+        set(CK_USE_FNUZ_FP8 "ON")
+    endif()
 else()
     add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
     set(CK_USE_XDL "ON")
diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp
index 67bf92bbbc..a3a62d4cfa 100644
--- a/example/01_gemm/common.hpp
+++ b/example/01_gemm/common.hpp
@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final
 struct ExecutionConfig final
 {
     // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
-    int do_verification = 3;
+    int do_verification = 1;
     int init_method     = 2;
     bool time_kernel    = false;
 };
diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc
index bafec3f358..3ee6e26856 100644
--- a/example/01_gemm/run_gemm_example.inc
+++ b/example/01_gemm/run_gemm_example.inc
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     switch(config.init_method)
     {
     case 0:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
         break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
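Note on the run_gemm_example.inc hunk above: replacing static_cast with ck::type_convert is a behavioral fix, not a cleanup. For the FP8 element types (plain 8-bit _BitInt values on FNUZ targets), static_cast<ADataType>(1.f) first truncates the float to the integer 1 and then treats that integer as an FP8 bit pattern, which is not the encoding of 1.0; ck::type_convert performs an actual float-to-FP8 value conversion. A minimal sketch of the difference, assuming ck::type_convert is reachable via ck/utility/type_convert.hpp (header path is an assumption) and an FNUZ build where ck::f8_t is _BitInt(8):

    // Sketch (not part of the patch): static_cast vs. ck::type_convert for FP8.
    #include "ck/utility/data_type.hpp"
    #include "ck/utility/type_convert.hpp" // assumed location of ck::type_convert

    int main()
    {
        using F8 = ck::f8_t;                   // _BitInt(8) on an FNUZ build
        F8 bits  = static_cast<F8>(1.f);       // integer truncation: bit pattern 0x01, a denormal
        F8 value = ck::type_convert<F8>(1.f);  // proper FP8 encoding of the value 1.0
        return bits == value ? 0 : 1;          // expected to return 1: the two differ
    }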
diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
index 8bbf8e629e..117a18e3bd 100644
--- a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         for(int j = 0; j < NumDMatrices; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
         }
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
         for(int j = 0; j < NumDMatrices; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
         }
     }
 }
diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
index e7b2ee4173..db162fe444 100644
--- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         for(int j = 0; j < NumDs; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
         }
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
         for(int j = 0; j < NumDs; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
         }
     }
 }
diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
index 3b3ef508ce..5bdc993192 100644
--- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 
-    d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+    d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<D0DataType, 1>{});
 }
 
 using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
index c1043f419d..6806bd1886 100644
--- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 }
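The generator changes in these grouped-GEMM examples all follow one pattern: GeneratorTensor_Sequential grows a leading element-type parameter (added in the host_tensor_generator.hpp hunk further down), so the generated coordinate is converted to the tensor's value type instead of being returned as a raw float. A small usage sketch; ADataType here is just a stand-in for whichever element type an example uses:

    // Sketch: the new two-parameter GeneratorTensor_Sequential<T, Dim>.
    #include "ck/library/utility/host_tensor.hpp"
    #include "ck/library/utility/host_tensor_generator.hpp"

    void fill_with_column_index()
    {
        using ADataType = ck::half_t; // stand-in element type
        Tensor<ADataType> a_m_k({3, 3});
        // Dim = 1: every element receives its column index, converted to ADataType.
        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 1>{});
        // Resulting values, row by row: 0 1 2 / 0 1 2 / 0 1 2
    }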
diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp
index c81874b066..8418c10f5e 100644
--- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp
@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 }
diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc
index 7cb0588b82..64125cd1d0 100644
--- a/example/15_grouped_gemm/run_grouped_gemm_example.inc
+++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 
 struct ProblemSize final
@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 }
diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
index 90d80f9f03..277fea0272 100644
--- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
+++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include
 #include
@@ -175,8 +175,8 @@ int main(int argc, char* argv[])
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 
     c0_n_bias.GenerateTensorValue(GeneratorTensor_2{-5, 5});
diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
index f329146728..d545508680 100644
--- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
+++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
         break;
     default:
         a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
index 27602e2313..1514fc48b3 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
         break;
     default:
         a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
index fa76faea84..2b02069e65 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -118,7 +118,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
index 2e77479bcc..e0ccb6dad1 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
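A note on the initialization pattern these attention examples keep using: GeneratorTensor_Diagonal fills a square matrix as an identity, so with B0 = I the first GEMM reproduces A and the softmax input can be read off directly, while the sequential A makes per-element mismatches easy to localize. A hedged sketch (it assumes GeneratorTensor_Diagonal takes the element type as its template parameter, as the stripped template arguments in these hunks suggest):

    // Sketch: identity-matrix B0 makes A * B0 == A, isolating softmax/second-GEMM bugs.
    #include "ck/library/utility/host_tensor.hpp"
    #include "ck/library/utility/host_tensor_generator.hpp"

    void attention_debug_init()
    {
        Tensor<ck::half_t> a({4, 8});  // M x K query block
        Tensor<ck::half_t> b0({8, 8}); // K x N key block, square so it can be an identity
        a.GenerateTensorValue(GeneratorTensor_Sequential<ck::half_t, 1>{});
        b0.GenerateTensorValue(GeneratorTensor_Diagonal<ck::half_t>{}); // assumed signature
    }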
 
 int run(int argc, char* argv[])
 {
@@ -153,7 +153,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
index 9ff4c56e06..0ad031cc71 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -178,7 +178,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
index ea1e2734a6..cdfd86dff4 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -152,7 +152,7 @@ int run(int argc, char* argv[])
         break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc
index 609d085299..7ac29f33ca 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc
index b05915c07f..fb9b1b0bd7 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
index 3fdaaebb0f..2cb69380e5 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 int run(int argc, char* argv[])
 {
@@ -173,7 +173,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc
index e3690984ab..cb1d3410c9 100644
--- a/example/35_splitK_gemm/run_splitK_gemm_example.inc
+++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 
 struct ProblemSize final
@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
index ff1282f3c7..f27dc60541 100644
--- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
+++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
@@ -377,7 +377,7 @@ int main(int argc, char* argv[])
         break;
     default:
         a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         d00_g_m_n.GenerateTensorValue(GeneratorTensor_1{1});
         d01_g_m_n.GenerateTensorValue(GeneratorTensor_1{1});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp
index 8a0474156c..6af8ac6488 100644
--- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp
+++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -41,7 +41,7 @@ struct ExecutionConfig final
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = true;
+    bool time_kernel     = false;
 };
 
 #define DefaultConvParams \
diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp
index a90a6340a4..392cb155cb 100644
--- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp
+++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
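For readers landing on the split-K example touched above: split-K GEMM partitions the reduction dimension into k_batch slices that are computed by independent workgroups and summed afterwards, i.e. (notation mine, not from the patch)

    C = \sum_{b=0}^{k_{batch}-1} A_{:,K_b} \, B_{K_b,:}

where K_b is the b-th contiguous slice of the K dimension. Deterministic inputs such as GeneratorTensor_Sequential make it far easier to attribute a wrong element to a particular slice than random initialization does.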
 
 #include
 #include
 
@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
index 742fd5547a..055d253042 100644
--- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
+++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include
 #include
 
@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
         break;
     default:
-        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-        b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
+        b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
+        b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
     }
 
     d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
index 809c1a956c..1ba8133ea7 100644
--- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
+++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
+        a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 
     d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
index 2568754648..9b7849a654 100644
--- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
+++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
     a1_device_buf.ToDevice(a1_m_k.mData.data());
     b0_device_buf.ToDevice(b0_k_n.mData.data());
     b1_device_buf.ToDevice(b1_k_n.mData.data());
-    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
 
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;
 
-    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
-
     if(do_verification)
     {
         Tensor c_m_n({M, N});
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index ea739c7071..72759916af 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
     endforeach()
+    #Do not build any DPP examples if DL_KERNELS not set
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
+            message("removing dpp example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
     #Do not build any XDL examples if gfx9 targets are not on the list
     foreach(source IN LISTS FILE_NAME)
         if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp
index a58acaf116..18e1db462a 100644
--- a/include/ck/library/utility/host_tensor.hpp
+++ b/include/ck/library/utility/host_tensor.hpp
@@ -326,7 +326,7 @@ struct Tensor
 
     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
 
-    void SetZero() { ck::ranges::fill(mData, 0); }
+    void SetZero() { ck::ranges::fill(mData, T{0}); }
 
     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp
index e87811b76b..ab9f01b53c 100644
--- a/include/ck/library/utility/host_tensor_generator.hpp
+++ b/include/ck/library/utility/host_tensor_generator.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::half_t operator()(Is...)
    {
        return ck::type_convert<ck::half_t>(value);
    }
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::f8_t operator()(Is...)
    {
        return ck::type_convert<ck::f8_t>(value);
    }
@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
     }
 };
 
-template <index_t Dim>
+/**
+ * @brief Is used to generate sequential values based on the specified dimension.
+ *
+ * @tparam T The type of the tensor values.
+ * @tparam Dim The specific dimension used for generation.
+ *
+ * GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor:
+ *
+ *  0 1 2
+ *  0 1 2
+ *  0 1 2
+ *
+ * Essentially, the values generated are logical coordinates of the generated element that
+ * correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column
+ * indices.
+ *
+ */
+template <typename T, index_t Dim>
 struct GeneratorTensor_Sequential
 {
     template <typename... Ts>
-    float operator()(Ts... 
Xs) const { std::array dims = {{static_cast(Xs)...}}; - return dims[Dim]; + + float tmp = dims[Dim]; + return ck::type_convert(tmp); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c1f58ccda5..a7f129b2b2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -111,8 +111,7 @@ __global__ void [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index da6b1b304e..813acfa656 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -38,8 +38,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index d4ee5c886c..5367c3d720 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); @@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #else - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; return src_thread_element_valid ? 
tmp : vector_t(0); #endif } @@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; return src_thread_element_valid ? tmp : vector_t(customized_value); } diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp new file mode 100644 index 0000000000..7b21ad6464 --- /dev/null +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -0,0 +1,988 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/random_gen.hpp" +#include "ck/utility/type.hpp" + +#ifdef CK_USE_FNUZ_FP8 +#define CK_USE_FNUZ_FP8 1 +#else +#define CK_USE_FNUZ_FP8 0 +#endif + +#ifdef CK_USE_OCP_FP8 +#define CK_USE_OCP_FP8 1 +#else +#define CK_USE_OCP_FP8 0 +#endif + +namespace ck { + +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); + +#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \ + defined(__gfx1201__)) && \ + __HIP_DEVICE_COMPILE__ +#define CK_FP8_CVT_FAST_PATH 1 +#else +#define CK_FP8_CVT_FAST_PATH 0 +#endif + +#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__ +#define CK_OCP_FP8_CVT_FAST_PATH 1 +#else +#define CK_OCP_FP8_CVT_FAST_PATH 0 +#endif + +typedef unsigned char fp8_storage_t; + +/** + * \brief Describes FP8 interpretation + */ +enum class ck_fp8_interpretation_t +{ + CK_E4M3_OCP = 0, // OCP E4M3 + CK_E5M2_OCP = 1, // OCP E5M2 + CK_E4M3_FNUZ = 2, // FP8 + CK_E5M2_FNUZ = 3, // BF8 +}; + +/** + * \brief Describes saturation behavior + */ +enum class ck_saturation_t +{ + CK_NOSAT = 0, // No saturation - replace with NaN or Inf + CK_SATFINITE = 1, // Saturate to finite +}; + +namespace fp8_impl { + +typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2))); +typedef float float2_t __attribute__((ext_vector_type(2))); + +__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a) +{ + return static_cast(a) == 0x80; +} +__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a) +{ + return static_cast(a) == 0x80; +} + +__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a) +{ + return (a & 0x7f) == 0x7f; +} +__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a) +{ + return (a & 0x7f) > 0x7c; +} + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 +// This has been modified to handle double types as well +template +__host__ __device__ static inline T cast_from_f8(fp8_storage_t x) +{ + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, "only half, float and double are supported"); + + constexpr int weo = is_half ? 5 : (is_float ? 8 : 11); + constexpr int wmo = is_half ? 10 : (is_float ? 
23 : 52); + + T fInf, fNegInf, fNaN, fNeg0, fmax, fmin; + if constexpr(is_half) + { + const unsigned short int ihInf = 0x7C00; + const unsigned short int ihNegInf = 0xFC00; + const unsigned short int ihNaN = 0x7C01; + const unsigned short int ihNeg0 = 0x8000; + /* Max number in e5m2 57344*/ + const unsigned short int ifmax = 0x7B00; + const unsigned short int ifmin = 0xFB00; + + fInf = bit_cast<_Float16>(ihInf); + fNegInf = bit_cast<_Float16>(ihNegInf); + fNaN = bit_cast<_Float16>(ihNaN); + fNeg0 = bit_cast<_Float16>(ihNeg0); + fmax = bit_cast<_Float16>(ifmax); + fmin = bit_cast<_Float16>(ifmin); + } + else if constexpr(is_float) + { + const unsigned int ifInf = 0x7F800000; + const unsigned int ifNegInf = 0xFF800000; + const unsigned int ifNaN = 0x7F800001; + const unsigned int ifNeg0 = 0x80000000; + /* Max number in e5m2 57344*/ + const unsigned int ifmax = 0x47600000; + const unsigned int ifmin = 0xC7600000; + + fInf = bit_cast(ifInf); + fNegInf = bit_cast(ifNegInf); + fNaN = bit_cast(ifNaN); + fNeg0 = bit_cast(ifNeg0); + fmax = bit_cast(ifmax); + fmin = bit_cast(ifmin); + } + else if constexpr(is_double) + { + const unsigned long long ifInf = 0x7FF0000000000000ull; + const unsigned long long ifNegInf = 0xFFF0000000000000ull; + const unsigned long long ifNaN = 0x7FF0000000000001ull; + const unsigned long long ifNeg0 = 0x8000000000000000ull; + /* Max number in e5m2 57344*/ + const unsigned long long ifmax = 0x40EC000000000000ull; + const unsigned long long ifmin = 0xC0EC000000000000ull; + + fInf = bit_cast(ifInf); + fNegInf = bit_cast(ifNegInf); + fNaN = bit_cast(ifNaN); + fNeg0 = bit_cast(ifNeg0); + fmax = bit_cast(ifmax); + fmin = bit_cast(ifmin); + } + + if(x == 0) + { + return 0; + } + + unsigned long long sign = x >> 7; + unsigned long long mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if constexpr(is_fnuz) + { + if(x == 0x80) + { + return fNaN; + } + } + else + { + if(x == 0x80) + { + return fNeg0; + } + if constexpr(we == 4) + { // e4m3 + if((x & 0x7F) == 0x7F) + { + return fNaN; + } + } + else if((x & 0x7C) == 0x7C) + { // e5m2 + if((x & 0x3) == 0) + { + if constexpr(clip) + { + return sign ? fmin : fmax; + } + return sign ? fNegInf : fInf; + } + return fNaN; + } + } + + typename __hip_internal::conditional< + sizeof(T) == 2, + unsigned short int, + typename __hip_internal::conditional:: + type>::type retval; + + if constexpr(we == 5 && is_half && !is_fnuz) + { + retval = x << 8; + return bit_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 
1 : 0); + + // subnormal input + if(exponent == 0) + { +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __clz(mantissa) - (32 - wm); +#else + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); +#endif + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1ull << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if constexpr(sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else if constexpr(sizeof(T) == 4) + retval = (sign << 31) | (exponent << 23) | mantissa; + else + retval = (sign << 63) | (static_cast(exponent) << 52) | mantissa; + + return bit_cast(retval); +} + +#if CK_FP8_CVT_FAST_PATH +template +static __device__ float cast_to_f32_from_f8(fp8_storage_t v) +{ + union + { + unsigned int i32val; + unsigned char i8val[4]; + } val; + val.i8val[0] = v; + + static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E4M3_OCP || + interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only FNUZ and OCP interpretations are supported"); + + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); + } + else + { + return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); + } +} + +template +static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) +{ + const auto i16val = bit_cast(v); + + static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E4M3_OCP || + interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only FNUZ and OCP interpretations are supported"); + + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false); + } + else + { + return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false); + } +} + +#endif + +} // namespace fp8_impl + +struct f8_ocp_t +{ + using data_type = fp8_storage_t; + data_type data; + + static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE; + static constexpr ck_fp8_interpretation_t default_interpret = + ck_fp8_interpretation_t::CK_E4M3_OCP; + + static constexpr unsigned int we = 4; // exponent width + static constexpr unsigned int wm = 3; // mantissa width + + __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const + { + return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator float() const +#else + __host__ explicit operator float() const +#endif + { +#if CK_OCP_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32_from_f8(this->data); +#else + return fp8_impl::cast_from_f8( + this->data); // XXX: clip==false must be consistent with operator _Float16 +#endif + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator _Float16() const +#else + __host__ explicit operator _Float16() const +#endif + { +#if CK_OCP_FP8_CVT_FAST_PATH + return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); +#else + return 
fp8_impl::cast_from_f8<_Float16, wm, we, false>( + this->data); // XXX: clip==false must be consistent with operator float +#endif + } +}; + +struct bf8_ocp_t +{ + using data_type = fp8_storage_t; + data_type data; + + static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE; + static constexpr ck_fp8_interpretation_t default_interpret = + ck_fp8_interpretation_t::CK_E5M2_OCP; + + static constexpr unsigned int we = 5; // exponent width + static constexpr unsigned int wm = 2; // mantissa width + + __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const + { + return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator float() const + +#else + __host__ explicit operator float() const +#endif + { +#if defined(__gfx1200__) || defined(__gfx1201__) + return fp8_impl::cast_to_f32_from_f8(this->data); +#else + return fp8_impl::cast_from_f8( + this->data); // XXX: clip==false must be consistent with operator _Float16 +#endif + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator _Float16() const +#else + __host__ explicit operator _Float16() const +#endif + { +#if defined(__gfx1200__) || defined(__gfx1201__) + return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); +#else + return fp8_impl::cast_from_f8<_Float16, wm, we, false>( + this->data); // XXX: clip==false must be consistent with operator float +#endif + } +}; + +template +__host__ __device__ static inline constexpr bool fp8_is_nan(T); + +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a) +{ + return fp8_impl::ocp_f8_is_nan(a.data); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a) +{ + return fp8_impl::ocp_bf8_is_nan(a.data); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a) +{ + return fp8_impl::fnuz_f8_is_nan(a); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a) +{ + return fp8_impl::fnuz_bf8_is_nan(a); +} + +template || std::is_same_v || + std::is_same_v || std::is_same_v, + bool> = true> +__host__ __device__ static inline constexpr bool fp8_is_inf(T) +{ + return false; +} +template <> +__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a) +{ + return (a.data & 0x7f) == 0x7c; +} + +namespace fp8_impl { + +// Assertions to check for supported conversion types +#define __assert_ocp_support(interp) \ + { \ + if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \ + interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \ + { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } +#define __assert_fnuz_support(interp) \ + { \ + if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \ + interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \ + { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } + +__host__ __device__ static inline void +__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) +{ +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ +#if CK_USE_OCP_FP8 + __assert_ocp_support(interp); +#endif +#if CK_USE_FNUZ_FP8 + __assert_fnuz_support(interp); +#endif +#endif +} + +#if CK_FP8_CVT_FAST_PATH +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 +template +static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng 
= 0) +{ + fp8_storage_t i8data; + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; // NOTE: not endian independent + } val; + + unsigned int ival = 0; + val.fval = v; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0); + } + } + else + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } + } + } + + if constexpr(stochastic_rounding) + { + ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) + : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + i8data = val.i8val[0]; // little endian + } + else + { // RNE CVT + ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, + val.fval, + ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + } + return i8data; +} +#endif // CK_FP8_CVT_FAST_PATH + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 +// This has been modified to add double types conversion as well +template +__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0) +{ + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, + "Only half, float and double can be cast to f8"); + + constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); + + using T_bitwise = typename __hip_internal::conditional< + sizeof(T) == 2, + unsigned short int, + typename __hip_internal::conditional:: + type>::type; + T_bitwise x_bitwise = bit_cast(_x); + + unsigned long long x{x_bitwise}; + + unsigned long long head, mantissa; + int exponent, bias; + unsigned int sign; + unsigned long long fInf, mask; + + if constexpr(sizeof(T) == 8) + { + head = x & 0xFFF0000000000000ull; + mantissa = x & 0xFFFFFFFFFFFFFull; + exponent = (head >> 52) & 0x7FF; + sign = head >> 63; + bias = 1023; + fInf = 0x7FF0000000000000ull; + mask = 0x7FFFFFFFFFFFFFFFull; + } + else if constexpr(sizeof(T) == 4) + { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + fInf = 0x7F800000; + mask = 0x7FFFFFFF; + } + else + { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + fInf = 0x7C00; + mask = 0x7FFF; + } + unsigned int signed_inf = 0; + unsigned int nan = 0; + if constexpr(is_fnuz) + { + signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80; + nan = 0x80; + } + else + { + if constexpr(we == 4) + { // e4m3 + signed_inf = (sign << 7) + (clip ? 
0x7e : 0x7f); + } + else + { // e5m2 + signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c); + } + nan = (sign << 7) + 0x7f; + } + // Max values + unsigned long long ifmax = 0; + if constexpr(sizeof(T) == 8) + { + if constexpr(we == 5) + { // 57344 + ifmax = 0x40EC000000000000ull; + } + else + { + if constexpr(is_fnuz) + { // 240 + ifmax = 0x406E000000000000ull; + } + else + { // 448 + ifmax = 0x407C000000000000ull; + } + } + } + else if(sizeof(T) == 4) + { + if constexpr(we == 5) + { + ifmax = 0x47600000; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x43700000; + } + else + { + ifmax = 0x43E00000; + } + } + } + else + { + if constexpr(we == 5) + { + ifmax = 0x7B00; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x5B80; + } + else + { + ifmax = 0x5F00; + } + } + } + // Deal with inf and NaNs + if((x & fInf) == fInf) + { + if constexpr(is_fnuz) + return signed_inf; + + return mantissa != 0 ? nan : signed_inf; + } + + if((x & mask) > ifmax) + { + return signed_inf; + } + + if(x == 0) + { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + mostly concern fp16 here. In this case, f8 is usually in denormal. But there + could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has + exponent bias 16. It means that there are some numbers in fp16 denormal but they + are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers + where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 + (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implicit 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, act_exponent could be larger. 
Just + // that it does not need shift mantissa + } + mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) == + (1ull << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part and + make something not midpoint look like midpoint. For example, the fp16 number + 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right + by 4 bits, it would look like midpoint. + */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1ull << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1; + bool odd = + mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0) + { + if((1ull << mfmt) & mantissa) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } + else + { + if((1ull << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - 1; + if(f8_exponent > max_exp) + { + if constexpr(clip) + { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + if(f8_exponent == 0 && mantissa == 0) + return is_fnuz ? 
0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +/** + * \brief convert float to @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \param f float number + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; + rng = prand_generator(reinterpret_cast(&f), f); + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ +#else +__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ +#endif + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; + rng = prand_generator(reinterpret_cast(&f), f); + } + + if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP) + { + return cast_to_f8(f, rng); + } + else + { + __hip_assert(false && "FP8 type is not supported by current target device"); + return 0; + } +#endif // CK_FP8_CVT_FAST_PATH +} + +/** + * \brief convert _Float16 to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x _Float16 value + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) +#else +__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) +#endif +{ + return cvt_float_to_fp8(static_cast(x)); +} + +} // namespace fp8_impl + +// Declare a template function for fp8 conversion using RNE +template +__host__ __device__ constexpr Y f8_convert_rne(X x); + +// convert fp32 to fp8 with rounding to nearest even +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +// convert fp32 to bf8 with rounding to nearest even +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +// convert _Float16 to fp8 with rounding to nearest even +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(_Float16 x) +{ + return f8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(_Float16 x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +// Declare a template function for fp8 conversion using RNE +template +__host__ __device__ constexpr Y f8_convert_sr(X x); + +// convert fp32 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +// convert fp32 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) +{ + return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +// convert _Float16 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_ocp_t 
f8_convert_sr(_Float16 x) +{ + return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +// convert _Float16 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(_Float16 x) +{ + return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +#if CK_USE_OCP_FP8 +using f8_t = f8_ocp_t; +using bf8_t = bf8_ocp_t; +#define CK_FP8_TYPE_FNUZ 0 +#define CK_FP8_TYPE_OCP 1 +#else +using f8_t = f8_fnuz_t; +using bf8_t = bf8_fnuz_t; +#define CK_FP8_TYPE_FNUZ 1 +#define CK_FP8_TYPE_OCP 0 +#endif + +} // namespace ck diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a955279bc8..5a7030cca7 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -4,7 +4,7 @@ #pragma once namespace ck { -// Define the common macro for gfx94x models +// Define the common macro for MI300 models #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 39f532e0e9..a7dc071bc2 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/utility/amd_ck_fp8.hpp" #include "ck/utility/statically_indexed_array.hpp" namespace ck { @@ -10,8 +11,6 @@ namespace ck { using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); -using f8_t = _BitInt(8); -using bf8_t = unsigned _BitInt(8); inline constexpr auto next_pow2(uint32_t x) { @@ -19,14 +18,15 @@ inline constexpr auto next_pow2(uint32_t x) return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; } -// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, +// native types: bool template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value; + is_same::value || is_same::value || + is_same::value || is_same::value; } // vector_type @@ -166,16 +166,30 @@ struct scalar_type #endif template <> -struct scalar_type +struct scalar_type { - using type = f8_t; + using type = f8_fnuz_t; static constexpr index_t vector_size = 1; }; template <> -struct scalar_type +struct scalar_type { - using type = bf8_t; + using type = bf8_fnuz_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f8_ocp_t::data_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; @@ -1010,60 +1024,203 @@ struct vector_type()>> } }; -template -struct non_native_vector_base +template +struct non_native_vector_base; + +template +struct nnvb_data_t_selector { - using type = non_native_vector_base; + using type = unsigned _BitInt(8 * sizeof(T)); +}; - __host__ __device__ non_native_vector_base() = default; - __host__ __device__ non_native_vector_base(const type&) = default; - __host__ __device__ non_native_vector_base(type&&) = default; - __host__ __device__ ~non_native_vector_base() = default; +template <> +struct nnvb_data_t_selector +{ + using type = f8_ocp_t::data_type; +}; +template <> +struct nnvb_data_t_selector +{ + using type = bf8_ocp_t::data_type; +}; - T d[N]; +template +struct non_native_vector_base< + T, + N, + std::enable_if_t> +{ + 
using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = data_t __attribute__((ext_vector_type(N))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } +}; + +template +struct scalar_type>; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; }; // non-native vector_type implementation template struct vector_type()>> { - using d1_t = T; - using type = d1_t; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using type = d1_nnv_t; union alignas(next_pow2(1 * sizeof(T))) { d1_t d1_; StaticallyIndexedArray d1x1_; + d1_nnv_t d1_nnv_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{}} {} + __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value, + static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - return data_.d1x1_; + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } } template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value, + static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - return data_.d1x1_; + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + 
else + { + return err; + } } }; template struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; using type = d2_t; @@ -1081,10 +1238,11 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x2_; } @@ -1101,10 +1259,11 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x2_; } @@ -1122,9 +1281,10 @@ struct vector_type()>> template struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; using type = d4_t; @@ -1143,10 +1303,11 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x4_; } @@ -1167,10 +1328,11 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x4_; } @@ -1192,10 +1354,11 @@ struct vector_type()>> template struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; using type = d8_t; @@ -1215,11 +1378,12 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x8_; } @@ -1244,11 +1408,12 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return 
data_.d1x8_; } @@ -1274,11 +1439,12 @@ struct vector_type()>> template struct vector_type()>> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; using type = d16_t; @@ -1299,12 +1465,12 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x16_; } @@ -1333,12 +1499,12 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); - if constexpr(is_same::value) + if constexpr(is_same::value || is_same::value) { return data_.d1x16_; } @@ -1632,20 +1798,70 @@ using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; // f8 -using f8x2_t = typename vector_type::type; -using f8x4_t = typename vector_type::type; -using f8x8_t = typename vector_type::type; -using f8x16_t = typename vector_type::type; -using f8x32_t = typename vector_type::type; -using f8x64_t = typename vector_type::type; +using f8x2_fnuz_t = typename vector_type::type; +using f8x4_fnuz_t = typename vector_type::type; +using f8x8_fnuz_t = typename vector_type::type; +using f8x16_fnuz_t = typename vector_type::type; +using f8x32_fnuz_t = typename vector_type::type; +using f8x64_fnuz_t = typename vector_type::type; // bf8 -using bf8x2_t = typename vector_type::type; -using bf8x4_t = typename vector_type::type; -using bf8x8_t = typename vector_type::type; -using bf8x16_t = typename vector_type::type; -using bf8x32_t = typename vector_type::type; -using bf8x64_t = typename vector_type::type; +using bf8x2_fnuz_t = typename vector_type::type; +using bf8x4_fnuz_t = typename vector_type::type; +using bf8x8_fnuz_t = typename vector_type::type; +using bf8x16_fnuz_t = typename vector_type::type; +using bf8x32_fnuz_t = typename vector_type::type; +using bf8x64_fnuz_t = typename vector_type::type; + +// f8 +using f8x2_ocp_t = typename vector_type::type; +using f8x4_ocp_t = typename vector_type::type; +using f8x8_ocp_t = typename vector_type::type; +using f8x16_ocp_t = typename vector_type::type; +using f8x32_ocp_t = typename vector_type::type; +using f8x64_ocp_t = typename vector_type::type; + +// bf8 +using bf8x2_ocp_t = typename vector_type::type; +using bf8x4_ocp_t = typename vector_type::type; +using bf8x8_ocp_t = typename vector_type::type; +using bf8x16_ocp_t = typename vector_type::type; +using bf8x32_ocp_t = typename vector_type::type; +using bf8x64_ocp_t = typename vector_type::type; + +#if CK_FP8_TYPE_OCP +// f8 +using f8x2_t = f8x2_ocp_t; +using f8x4_t = f8x4_ocp_t; +using f8x8_t = f8x8_ocp_t; +using f8x16_t = f8x16_ocp_t; +using 
f8x32_t = f8x32_ocp_t; +using f8x64_t = f8x64_ocp_t; + +// bf8 +using bf8x2_t = bf8x2_ocp_t; +using bf8x4_t = bf8x4_ocp_t; +using bf8x8_t = bf8x8_ocp_t; +using bf8x16_t = bf8x16_ocp_t; +using bf8x32_t = bf8x32_ocp_t; +using bf8x64_t = bf8x64_ocp_t; +#elif CK_FP8_TYPE_FNUZ +// f8 +using f8x2_t = f8x2_fnuz_t; +using f8x4_t = f8x4_fnuz_t; +using f8x8_t = f8x8_fnuz_t; +using f8x16_t = f8x16_fnuz_t; +using f8x32_t = f8x32_fnuz_t; +using f8x64_t = f8x64_fnuz_t; + +// bf8 +using bf8x2_t = bf8x2_fnuz_t; +using bf8x4_t = bf8x4_fnuz_t; +using bf8x8_t = bf8x8_fnuz_t; +using bf8x16_t = bf8x16_fnuz_t; +using bf8x32_t = bf8x32_fnuz_t; +using bf8x64_t = bf8x64_fnuz_t; +#endif // u8 using uint8x2_t = typename vector_type::type; @@ -1702,7 +1918,7 @@ struct NumericLimits #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> -struct NumericLimits +struct NumericLimits { // negative zero nan mode with exp bias = 8 static constexpr uint8_t binary_min = 0x08; // 0b00001000 @@ -1715,17 +1931,17 @@ struct NumericLimits // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); } + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } - __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); } + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } - __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); } + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } - __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } }; template <> -struct NumericLimits +struct NumericLimits { // negative zero nan mode with exp bias = 16 static constexpr uint8_t binary_min = 0x04; // 0b00000100 @@ -1738,13 +1954,59 @@ struct NumericLimits // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); } + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } - __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); } + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } - __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); } + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } - __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr 
f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } }; template @@ -1787,7 +2049,7 @@ struct NumericUtils }; template <> -struct NumericUtils +struct NumericUtils { static constexpr int exp = 4; static constexpr int mant = 3; @@ -1796,13 +2058,28 @@ struct NumericUtils }; template <> -struct NumericUtils +struct NumericUtils { static constexpr int exp = 5; static constexpr int mant = 2; static constexpr int bias = 16; // negative zero nan mode // static constexpr int bias = 15; // ieee mode }; +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 7; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 15; +}; template <> struct NumericUtils diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index b374c4ad55..a6c3540d85 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; -static inline __host__ bool isnan(f8_t x) { return (x & 0x80); }; +static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); }; #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 static inline __host__ bool isnan(int4_t x) @@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; -static inline __device__ bool isnan(f8_t x) { return (x & 0x80); }; +static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); }; static inline __device__ half_t sqrt(half_t x) { diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp index b7edf26507..4ea52f7eb0 100644 --- a/include/ck/utility/random_gen.hpp +++ b/include/ck/utility/random_gen.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
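A note on the isnan() switch above: FNUZ and OCP fp8 place NaN in different bit patterns, so the old single-bit test (x & 0x80) is only correct for the FNUZ encoding. A minimal standalone sketch of the three predicates on the raw storage byte (illustrative helpers, not the library's fp8_is_nan implementation):

#include <cstdint>

// FNUZ (f8_fnuz_t / bf8_fnuz_t): the only NaN encoding is "negative zero",
// sign bit set and every other bit clear.
inline bool fnuz_fp8_is_nan(std::uint8_t x) { return x == 0x80u; }

// OCP E4M3 (f8_ocp_t): no infinities; NaN is S.1111.111, i.e. 0x7F or 0xFF.
inline bool ocp_f8_is_nan(std::uint8_t x) { return (x & 0x7Fu) == 0x7Fu; }

// OCP E5M2 (bf8_ocp_t): exponent all ones; zero mantissa is +/-inf (0x7C/0xFC),
// non-zero mantissa is NaN.
inline bool ocp_bf8_is_nan(std::uint8_t x) { return (x & 0x7Fu) > 0x7Cu; }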
#pragma once +#include "ck/ck.hpp" + namespace ck { // Pseudo random number generator @@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = } // version for fp16 -template {}, bool> = false> +template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { uint16_t x = *(reinterpret_cast(&val)); @@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = } // return 0 if data is not fp16 or fp32 -template {} || std::is_same{}), bool> = false> +template < + typename T, + uint32_t seed_t, + std::enable_if_t{} || std::is_same<_Float16, T>{}), bool> = false> __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) { std::ignore = id; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 87fa9aa38a..f372756e68 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -9,7 +9,7 @@ #include "ck/utility/array.hpp" namespace ck { -// Define the common macro for gfx94x models +// Define the common macro for MI300 models #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif @@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +template <> +inline __host__ __device__ constexpr f8_ocp_t type_convert(int x) +{ + return f8_ocp_t{type_convert(x)}; +} + +template <> +inline __host__ __device__ constexpr bf8_ocp_t type_convert(int x) +{ + return bf8_ocp_t{type_convert(x)}; +} + // Convert X to Y template __host__ __device__ constexpr Y type_convert_sp(X x) @@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); // convert fp32 to fp8 with stochastic rounding template <> -inline __host__ __device__ f8_t f8_convert_sr(float x) +inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) { constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); @@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; return utils:: - cast_to_f8(x, - rng); + cast_to_f8( + x, rng); #endif } // convert fp16 to fp8 with stochastic rounding template <> -inline __host__ __device__ f8_t f8_convert_sr(half_t x) +inline __host__ __device__ f8_fnuz_t f8_convert_sr(half_t x) { #if defined(__gfx94__) // convert to float and use native converion - return f8_convert_sr(type_convert(x)); + return f8_convert_sr(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); - return utils:: - cast_to_f8( - x, rng); + return utils::cast_to_f8(x, rng); #endif } // convert fp32 to bf8 with stochastic rounding template <> -inline __host__ __device__ bf8_t f8_convert_sr(float x) +inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) { constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); @@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr(float x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - return utils:: - cast_to_f8( - x, rng); + return utils::cast_to_f8(x, rng); #endif } // convert fp16 to bf8 with stochastic rounding template <> -inline 
__host__ __device__ bf8_t f8_convert_sr(half_t x) +inline __host__ __device__ bf8_fnuz_t f8_convert_sr(half_t x) { #if defined(__gfx94__) // convert to float and use native conversion - return f8_convert_sr(type_convert(x)); + return f8_convert_sr(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); - return utils:: - cast_to_f8( - x, rng); + return utils::cast_to_f8(x, rng); #endif } @@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); // convert fp32 to fp8 with rounding to nearest even template <> -inline __host__ __device__ f8_t f8_convert_rne(float x) +inline __host__ __device__ f8_fnuz_t f8_convert_rne(float x) { #if defined(__gfx94__) union @@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne(float x) constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; return utils:: - cast_to_f8(x, - rng); + cast_to_f8( + x, rng); #endif } // convert fp16 to fp8 with rounding to nearest even template <> -inline __host__ __device__ f8_t f8_convert_rne(half_t x) +inline __host__ __device__ f8_fnuz_t f8_convert_rne(half_t x) { #if defined(__gfx94__) // convert to float and use native conversion - return f8_convert_rne(type_convert(x)); + return f8_convert_rne(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils:: - cast_to_f8( - x, rng); + return utils::cast_to_f8(x, rng); #endif } // convert fp32 to bf8 with rounding to nearest even template <> -inline __host__ __device__ bf8_t f8_convert_rne(float x) +inline __host__ __device__ bf8_fnuz_t f8_convert_rne(float x) { #if defined(__gfx94__) union @@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne(float x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils:: - cast_to_f8( - x, rng); + return utils::cast_to_f8(x, rng); #endif } // convert fp16 to bf8 with rounding to nearest even template <> -inline __host__ __device__ bf8_t f8_convert_rne(half_t x) +inline __host__ __device__ bf8_fnuz_t f8_convert_rne(half_t x) { #if defined(__gfx94__) // convert to float and use native conversion - return f8_convert_rne(type_convert(x)); + return f8_convert_rne(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils:: - cast_to_f8( - x, rng); + return utils::cast_to_f8(x, rng); #endif } // convert fp32 to fp8 template <> -inline __host__ __device__ f8_t type_convert(float x) +inline __host__ __device__ f8_fnuz_t type_convert(float x) { #if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); + return f8_convert_sr(x); #else - return f8_convert_rne(x); + return f8_convert_rne(x); +#endif +} + +// convert fp32 to fp8 +template <> +inline __host__ __device__ f8_ocp_t type_convert(float x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); #endif } // convert fp8 to fp32 template <> -inline __host__ __device__ float type_convert(f8_t x) +inline __host__ __device__ float type_convert(f8_fnuz_t x) { #if defined(__gfx94__) float fval; @@ -392,30 +427,44 @@ 
return fval; #else constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); #endif } template <> -inline __host__ __device__ float2_t type_convert(f8x2_t x) +inline __host__ __device__ float2_t type_convert(f8x2_fnuz_t x) { #if defined(__gfx94__) const auto i16val = bit_cast(x); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); #else constexpr bool negative_zero_nan = true; - const auto f8x2_v = vector_type(x); + const auto f8x2_v = vector_type(x); vector_type f32x2_v; f32x2_v.template AsType()(Number<0>{}) = - utils::cast_from_f8( - f8x2_v.template AsType()[Number<0>{}]); + utils::cast_from_f8( + f8x2_v.template AsType()[Number<0>{}]); f32x2_v.template AsType()(Number<1>{}) = - utils::cast_from_f8( - f8x2_v.template AsType()[Number<1>{}]); + utils::cast_from_f8( + f8x2_v.template AsType()[Number<1>{}]); return f32x2_v.template AsType()[Number<0>{}]; #endif } +template <> +inline __host__ __device__ float2_t type_convert(f8x2_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32x2_from_f8x2( + x.AsType()[Number<0>{}]); +#else + return float2_t{fp8_impl::cast_from_f8( + x.AsType()[Number<0>{}]), + fp8_impl::cast_from_f8( + x.AsType()[Number<1>{}])}; +#endif +} + template <> inline __host__ __device__ half2_t type_convert(float2_t x) { @@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert(float2_t x) // convert fp16 to fp8 template <> -inline __host__ __device__ f8_t type_convert(half_t x) +inline __host__ __device__ f8_fnuz_t type_convert(half_t x) { #if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); + return f8_convert_sr(x); #else - return f8_convert_rne(x); + return f8_convert_rne(x); +#endif +} + +// convert fp16 to fp8 +template <> +inline __host__ __device__ f8_ocp_t type_convert(half_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); #endif } // convert fp8 to fp16 template <> -inline __host__ __device__ half_t type_convert(f8_t x) +inline __host__ __device__ half_t type_convert(f8_fnuz_t x) { #if defined(__gfx94__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); #else constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); #endif } // convert fp32 to bf8 template <> -inline __host__ __device__ bf8_t type_convert(float x) +inline __host__ __device__ bf8_fnuz_t type_convert(float x) { #if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); + return f8_convert_sr(x); #else - return f8_convert_rne(x); + return f8_convert_rne(x); +#endif +} + +// convert fp32 to bf8 +template <> +inline __host__ __device__ bf8_ocp_t type_convert(float x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); #endif } // convert bf8 to fp32 template <> -inline __host__ __device__ float type_convert(bf8_t x) +inline __host__ __device__ float type_convert(bf8_fnuz_t x) { #if defined(__gfx94__) float fval; @@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert(bf8_t x) return fval; #else constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); #endif } // convert fp16 to bf8 template <> -inline __host__ __device__ bf8_t type_convert(half_t x) +inline __host__ __device__ bf8_fnuz_t type_convert(half_t x) { #if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); + return f8_convert_sr(x); #else - return f8_convert_rne(x); + return f8_convert_rne(x); +#endif +} + +// convert fp16 to 
bf8 +template <> +inline __host__ __device__ bf8_ocp_t type_convert(half_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); #endif } // convert bf8 to fp16 template <> -inline __host__ __device__ half_t type_convert(bf8_t x) +inline __host__ __device__ half_t type_convert(bf8_fnuz_t x) { #if defined(__gfx94__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); #else constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); #endif } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index e1edc4fae0..1ae11fe9db 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_m_k_.mDesc.GetLengths()[1]; - AccDataType v_acc = 0; - ComputeTypeA v_a = 0; - ComputeTypeB v_b = 0; + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; for(int k = 0; k < K; ++k) { @@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } - CDataType v_c = 0; + CDataType v_c{0}; arg.c_element_op_(v_c, v_acc); diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 2c0b6c7b75..dd023e6b51 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -62,7 +62,7 @@ function(add_instance_library INSTANCE_NAME) endforeach() # Do not build mha instances if gfx94 or gfx90a targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha") + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha") message("removing mha instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -346,7 +346,7 @@ if(CK_DEVICE_CONV_INSTANCES) endif() if(CK_DEVICE_MHA_INSTANCES) set(gpu_list ${INST_TARGETS}) - if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a") + if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a") add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES}) add_library(composablekernels::device_mha_operations ALIAS device_mha_operations) target_compile_features(device_mha_operations PUBLIC) diff --git a/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp index af31cf8a86..e31433cc81 100644 --- a/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f8_instance.cpp @@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances( instances) { add_device_operation_instances( - instances, device_pool3d_fwd_ndhwc_instances{}); + instances, 
device_pool3d_fwd_ndhwc_instances{}); } void add_device_pool3d_fwd_ndhwc_index_f8_instances( @@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances( instances) { add_device_operation_instances( - instances, device_pool3d_fwd_ndhwc_instances{}); + instances, device_pool3d_fwd_ndhwc_instances{}); } } // namespace instance diff --git a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp index 5bee67c1ce..be69b67b5c 100644 --- a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, break; default: a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); - b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); } diff --git a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp index f3d2c55617..b585b7d56a 100644 --- a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, break; default: a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp index 15a21206c5..700ada73a1 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
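The type_convert specializations above give both fp8 flavors a single call site: CK_USE_SR_F8_CONVERSION selects stochastic rounding or round-to-nearest-even at build time, and the destination type selects the FNUZ or OCP encoder. A hypothetical round-trip sketch under that assumption (the scalar fp8-to-float direction relies on the matching specializations elsewhere in this change):

#include "ck/utility/type_convert.hpp"

// Quantize a float to OCP e4m3 and report the quantization error.
inline float fp8_roundtrip_error(float x)
{
    const ck::f8_ocp_t q = ck::type_convert<ck::f8_ocp_t>(x); // SR or RNE per CK_USE_SR_F8_CONVERSION
    const float back = ck::type_convert<float>(q);            // decode back to fp32
    return back - x;
}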
#pragma once @@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, break; default: a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp index f2fcb0b133..e3c462e21c 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, break; default: a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); - b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 0419ccd8e7..1373dbc497 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification, switch(init_method) { case 0: - ck::utils::FillConstant{static_cast(1.f)}(a_m_k); - ck::utils::FillConstant{static_cast(1.f)}(b_k_n); + ck::utils::FillConstant{type_convert(1.f)}(a_m_k); + ck::utils::FillConstant{type_convert(1.f)}(b_k_n); break; case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index a783be7bb0..a9d3dad7f3 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4) endif() endif() -add_gtest_executable(test_fp8 test_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_fp8 PRIVATE utility) + + +add_custom_target(test_fp8) + +if (CK_USE_OCP_FP8) + add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp) + if(result EQUAL 0) + target_link_libraries(test_fp8_ocp PRIVATE utility) + endif() + + add_gtest_executable(test_bf8_ocp test_bf8_ocp.cpp) + if(result EQUAL 0) + target_link_libraries(test_bf8_ocp PRIVATE utility) + endif() + + add_dependencies(test_fp8 test_fp8_ocp) + add_dependencies(test_fp8 test_bf8_ocp) endif() -add_gtest_executable(test_bf8 test_bf8.cpp) -if(result EQUAL 0) - target_link_libraries(test_bf8 PRIVATE utility) + +if (CK_USE_FNUZ_FP8) + add_gtest_executable(test_fp8_fnuz test_fp8_fnuz.cpp) + if(result EQUAL 0) + target_link_libraries(test_fp8_fnuz PRIVATE utility) + endif() + + add_gtest_executable(test_bf8_fnuz test_bf8_fnuz.cpp) + if(result EQUAL 0) + target_link_libraries(test_bf8_fnuz PRIVATE utility) + endif() + + add_dependencies(test_fp8 test_fp8_fnuz) + add_dependencies(test_fp8 test_bf8_fnuz) endif() add_gtest_executable(test_custom_type test_custom_type.cpp) diff --git 
a/test/data_type/test_bf8.cpp b/test/data_type/test_bf8_fnuz.cpp similarity index 52% rename from test/data_type/test_bf8.cpp rename to test/data_type/test_bf8_fnuz.cpp index 6f50db68c7..4ff796a614 100644 --- a/test/data_type/test_bf8.cpp +++ b/test/data_type/test_bf8_fnuz.cpp @@ -5,158 +5,169 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" -using ck::bf8_t; +using ck::bf8_fnuz_t; using ck::f8_convert_rne; using ck::f8_convert_sr; using ck::half_t; using ck::type_convert; -TEST(BF8, NumericLimits) +TEST(BF8FNUZ, NumericLimits) { // constants given for negative zero nan mode - EXPECT_EQ(ck::NumericLimits::Min(), type_convert(0x04)); - EXPECT_EQ(ck::NumericLimits::Max(), type_convert(0x7F)); - EXPECT_EQ(ck::NumericLimits::Lowest(), type_convert(0xFF)); - EXPECT_EQ(ck::NumericLimits::QuietNaN(), type_convert(0x80)); + EXPECT_EQ(ck::NumericLimits::Min(), type_convert(0x04)); + EXPECT_EQ(ck::NumericLimits::Max(), type_convert(0x7F)); + EXPECT_EQ(ck::NumericLimits::Lowest(), type_convert(0xFF)); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), type_convert(0x80)); } -TEST(BF8, ConvertFP32Nearest) +TEST(BF8FNUZ, ConvertFP32Nearest) { // fix the tolerance value float abs_tol = 1e-6; // convert 0 float to bf8 and back, check if holds - ASSERT_NEAR(0.0f, type_convert(f8_convert_rne(0.0f)), abs_tol); + ASSERT_NEAR(0.0f, type_convert(f8_convert_rne(0.0f)), abs_tol); // don't run the next test on gfx11 devices #ifndef CK_SKIP_FLAKY_F8_TEST // convert minimal float to bf8 and back, check if holds ASSERT_NEAR(std::numeric_limits::min(), - type_convert(f8_convert_rne(std::numeric_limits::min())), + type_convert(f8_convert_rne(std::numeric_limits::min())), abs_tol); #endif - // convert maximal bf8_t to float and check if equal to 57344.0 - ASSERT_NEAR(57344.0f, type_convert(f8_convert_rne(57344.0f)), abs_tol); + + const auto max_bf8_t_float = type_convert(ck::NumericLimits::Max()); + // convert maximal bf8_fnuz_t to float and check if equal to 57344.0 + ASSERT_NEAR( + max_bf8_t_float, type_convert(f8_convert_rne(max_bf8_t_float)), abs_tol); // convert maximal float to bf8 and back, check if clipped to 57344.0 - ASSERT_NEAR(57344.0f, - type_convert(f8_convert_rne(std::numeric_limits::max())), + ASSERT_NEAR(max_bf8_t_float, + type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); - // convert inf float to bf8_t and check if it is qNan - ASSERT_NEAR(type_convert(0x80), - f8_convert_rne(std::numeric_limits::infinity()), + // convert inf float to bf8_fnuz_t and check if it is qNan + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity()), abs_tol); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; - ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); + ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); // negative norm float value to bf8 and back, check if holds float neg_float = -0.0000610351f; - ASSERT_NEAR(neg_float, type_convert(f8_convert_rne(neg_float)), abs_tol); + ASSERT_NEAR(neg_float, type_convert(f8_convert_rne(neg_float)), abs_tol); // positive subnorm float value to bf8 and back, check if holds pos_float = 0.0000305175f; - ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); + ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); // negative subnorm float value to bf8 and back, check if holds neg_float = -0.0000152587f; - ASSERT_NEAR(neg_float, type_convert(f8_convert_rne(neg_float)), abs_tol); + 
ASSERT_NEAR(neg_float, type_convert(f8_convert_rne(neg_float)), abs_tol); } -TEST(BF8, ConvertFP32Stochastic) +TEST(BF8FNUZ, ConvertFP32Stochastic) { // fix the tolerance value float abs_tol = 1e-6; // convert 0 float to bf8 and back, check if holds - ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); // convert minimal float to bf8 and back, check if holds ASSERT_NEAR(std::numeric_limits::min(), - type_convert(f8_convert_sr(std::numeric_limits::min())), + type_convert(f8_convert_sr(std::numeric_limits::min())), abs_tol); - // convert maximal bf8_t to float and check if equal to 57344.0 - ASSERT_NEAR(57344.0f, type_convert(f8_convert_sr(57344.0f)), abs_tol); + + const auto max_bf8_t_float = type_convert(ck::NumericLimits::Max()); + // convert maximal bf8_fnuz_t to float and check if equal to 57344.0 + ASSERT_NEAR( + max_bf8_t_float, type_convert(f8_convert_sr(max_bf8_t_float)), abs_tol); // convert maximal float to bf8 and back, check if clipped to 57344.0 - ASSERT_NEAR(57344.0f, - type_convert(f8_convert_sr(std::numeric_limits::max())), + ASSERT_NEAR(max_bf8_t_float, + type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); - // convert inf float to bf8_t and check if it is qNan - ASSERT_NEAR(type_convert(0x80), - f8_convert_sr(std::numeric_limits::infinity()), + // convert inf float to bf8_fnuz_t and check if it is qNan + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity()), abs_tol); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; - ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); // negative norm float value to bf8 and back, check if holds float neg_float = -0.0000610351f; - ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); // positive subnorm float value to bf8 and back, check if holds pos_float = 0.0000305175f; - ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); // negative subnorm float value to bf8 and back, check if holds neg_float = -0.0000152587f; - ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); } -TEST(BF8, ConvertFP16Nearest) +TEST(BF8FNUZ, ConvertFP16Nearest) { // fix the tolerance value float abs_tol = 1e-3; // convert 0 fp16 to bf8 and back, check if holds - ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_rne(half_t{0.0})), abs_tol); + ASSERT_NEAR( + half_t{0.0}, type_convert(f8_convert_rne(half_t{0.0})), abs_tol); // convert minimal fp16 to bf8 and back, check if holds ASSERT_NEAR(ck::NumericLimits::Min(), - type_convert(f8_convert_rne(ck::NumericLimits::Min())), + type_convert(f8_convert_rne(ck::NumericLimits::Min())), abs_tol); - // convert maximal bf8_t to fp16 and check if equal to 57344.0 + + const auto max_bf8_t_half = type_convert(ck::NumericLimits::Max()); + // convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0 ASSERT_NEAR( - half_t{57344.0}, type_convert(f8_convert_rne(half_t{57344.0})), abs_tol); + max_bf8_t_half, type_convert(f8_convert_rne(max_bf8_t_half)), abs_tol); // convert maximal fp16 to bf8 and back, check if clipped to 57344.0 - ASSERT_NEAR(half_t{57344.0}, - 
type_convert(f8_convert_rne(ck::NumericLimits::Max())), + ASSERT_NEAR(max_bf8_t_half, + type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); - // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN - ASSERT_NEAR(type_convert(0x80), - f8_convert_rne(ck::NumericLimits::QuietNaN()), + // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN()), abs_tol); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; - ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); + ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); // negative norm fp16 value to bf8 and back, check if holds half_t neg_half = half_t{-0.0000610351}; - ASSERT_NEAR(neg_half, type_convert(f8_convert_rne(neg_half)), abs_tol); + ASSERT_NEAR(neg_half, type_convert(f8_convert_rne(neg_half)), abs_tol); // positive subnorm fp16 value to bf8 and back, check if holds pos_half = half_t{0.0000305175}; - ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); + ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); // negative subnorm fp16 value to bf8 and back, check if holds neg_half = half_t{-0.0000152587}; - ASSERT_NEAR(neg_half, type_convert(f8_convert_rne(neg_half)), abs_tol); + ASSERT_NEAR(neg_half, type_convert(f8_convert_rne(neg_half)), abs_tol); } -TEST(BF8, ConvertFP16Stochastic) +TEST(BF8FNUZ, ConvertFP16Stochastic) { // fix the tolerance value float abs_tol = 1e-3; // convert 0 fp16 to bf8 and back, check if holds - ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); + ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); // convert minimal fp16 to bf8 and back, check if holds ASSERT_NEAR(ck::NumericLimits::Min(), - type_convert(f8_convert_sr(ck::NumericLimits::Min())), + type_convert(f8_convert_sr(ck::NumericLimits::Min())), abs_tol); - // convert maximal bf8_t to fp16 and check if equal to 57344.0 + + const auto max_bf8_t_half = type_convert(ck::NumericLimits::Max()); + // convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0 ASSERT_NEAR( - half_t{57344.0}, type_convert(f8_convert_sr(half_t{57344.0})), abs_tol); + max_bf8_t_half, type_convert(f8_convert_sr(max_bf8_t_half)), abs_tol); // convert maximal fp16 to bf8 and back, check if clipped to 57344.0 - ASSERT_NEAR(half_t{57344.0}, - type_convert(f8_convert_sr(ck::NumericLimits::Max())), + ASSERT_NEAR(max_bf8_t_half, + type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); - // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN - ASSERT_NEAR(type_convert(0x80), - f8_convert_sr(ck::NumericLimits::QuietNaN()), + // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN + ASSERT_NEAR(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN()), abs_tol); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; - ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); // negative norm fp16 value to bf8 and back, check if holds half_t neg_half = half_t{-0.0000610351}; - ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); // positive subnorm fp16 value to bf8 and back, check if holds pos_half = half_t{0.0000305175}; - 
ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); // negative subnorm fp16 value to bf8 and back, check if holds neg_half = half_t{-0.0000152587}; - ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); } diff --git a/test/data_type/test_bf8_ocp.cpp b/test/data_type/test_bf8_ocp.cpp new file mode 100644 index 0000000000..9d4ee38b15 --- /dev/null +++ b/test/data_type/test_bf8_ocp.cpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::bf8_ocp_t; +using ck::f8_convert_rne; +using ck::f8_convert_sr; +using ck::half_t; +using ck::type_convert; + +TEST(BF8OCP, NumericLimits) +{ // constants given for OCP FP8 + EXPECT_EQ(ck::NumericLimits::Min(), + type_convert(0x04)); // 0b00000100 = 2^-14 + EXPECT_EQ(ck::NumericLimits::Max(), + type_convert(0x7B)); // 0b01111011 = 57344 + EXPECT_EQ(ck::NumericLimits::Lowest(), + type_convert(0xFB)); // 0b11111011 = -57344 + EXPECT_EQ(ck::NumericLimits::QuietNaN().data, + type_convert(0x7D).data); // 0b01111101 + EXPECT_FALSE(ck::NumericLimits::QuietNaN() == + ck::NumericLimits::QuietNaN()); + EXPECT_TRUE(ck::fp8_is_inf(type_convert(0xFC)) && + ck::fp8_is_inf(type_convert(0x7C))); +} + +TEST(BF8OCP, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + + // convert 0 float to bfp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_rne(0.0f)), 0.0f); + + // convert minimal float to bf8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_rne(std::numeric_limits::min())), + abs_tol); + + const auto max_bf8_t_float = type_convert(ck::NumericLimits::Max()); + + // convert maximal bf8_ocp_t to float and check if equal to bf8 max + ASSERT_NEAR( + max_bf8_t_float, type_convert(f8_convert_rne(max_bf8_t_float)), 0.0f); + + // convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite) + ASSERT_NEAR(max_bf8_t_float, + type_convert(f8_convert_rne(std::numeric_limits::max())), + 0.0f); + + // convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite) + ASSERT_EQ(ck::NumericLimits::Max(), + f8_convert_rne(std::numeric_limits::infinity())); + + // positive normal float value to bf8 and back, check if holds + float pos_float = 0.0000762939f; // 10*2^-17 + ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); + + // negative smallest normal bf8 value to bf8 and back, check if holds + constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14 + ASSERT_NEAR(neg_min_bf8, type_convert(f8_convert_rne(neg_min_bf8)), 0.0f); + + // positive subnorm float value to bf8 and back, check if holds + constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15 + ASSERT_NEAR( + pos_subnorm_bf8, type_convert(f8_convert_rne(pos_subnorm_bf8)), 0.0f); + + // min subnorm bf8 value to bf8 and back, check if holds + constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16 + ASSERT_NEAR( + min_subnorm_bf8, type_convert(f8_convert_rne(min_subnorm_bf8)), 0.0f); + + // smaller than min subnorm bf8 value to bf8 must be zero + constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17 + ASSERT_EQ(0.0f, type_convert(f8_convert_rne(less_than_min_subnorm))); + + // 
convert quiet NaN to bf8_ocp_t and check if it is quiet NaN + const auto bf8_nan = f8_convert_rne(std::numeric_limits::quiet_NaN()); + ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data)); +} + +TEST(BF8OCP, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + + // convert 0 float to bfp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), 0.0f); + + // convert minimal float to bf8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_sr(std::numeric_limits::min())), + abs_tol); + + const auto max_bf8_t_float = type_convert(ck::NumericLimits::Max()); + + // convert maximal bf8_ocp_t to float and check if equal to bf8 max + ASSERT_NEAR( + max_bf8_t_float, type_convert(f8_convert_sr(max_bf8_t_float)), 0.0f); + + // convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite) + ASSERT_NEAR(max_bf8_t_float, + type_convert(f8_convert_sr(std::numeric_limits::max())), + 0.0f); + + // convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite) + ASSERT_EQ(ck::NumericLimits::Max(), + f8_convert_sr(std::numeric_limits::infinity())); + + // positive normal float value to bf8 and back, check if holds + float pos_float = 0.0000762939f; // 10*2^-17 + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + + // negative smallest normal bf8 value to bf8 and back, check if holds + constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14 + ASSERT_NEAR(neg_min_bf8, type_convert(f8_convert_sr(neg_min_bf8)), 0.0f); + + // positive subnorm float value to bf8 and back, check if holds + constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15 + ASSERT_NEAR( + pos_subnorm_bf8, type_convert(f8_convert_sr(pos_subnorm_bf8)), 0.0f); + + // min subnorm bf8 value to bf8 and back, check if holds + constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16 + ASSERT_NEAR( + min_subnorm_bf8, type_convert(f8_convert_sr(min_subnorm_bf8)), 0.0f); + + // smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16 + constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17 + ASSERT_NEAR(0.0f, + type_convert(f8_convert_sr(less_than_min_subnorm)), + 0.0000152587890625f); + + // convert quiet NaN to bf8_ocp_t and check if it is quiet NaN + const auto bf8_nan = f8_convert_sr(std::numeric_limits::quiet_NaN()); + ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data)); +} + +TEST(BF8OCP, ConvertFP16Nearest) +{ + // fix the tolerance value + constexpr half_t half_t_tol = 1e-3; + constexpr half_t half_t_zero = 0.0; + + // convert 0 half_t to bfp8 and back, check if holds + ASSERT_NEAR( + half_t_zero, type_convert(f8_convert_rne(half_t_zero)), half_t_zero); + + // convert minimal half_t to bf8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(f8_convert_rne(ck::NumericLimits::Min())), + half_t_tol); + + const auto max_bf8_t_half_t = type_convert(ck::NumericLimits::Max()); + + // convert maximal bf8_ocp_t to half_t and check if equal to bf8 max + ASSERT_NEAR(max_bf8_t_half_t, + type_convert(f8_convert_rne(max_bf8_t_half_t)), + half_t_zero); + + // convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite) + ASSERT_NEAR(max_bf8_t_half_t, + type_convert(f8_convert_rne(ck::NumericLimits::Max())), + half_t_zero); + + // convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite) + ASSERT_EQ( + ck::NumericLimits::Max(), + 
+TEST(BF8OCP, ConvertFP16Nearest)
+{
+    // fix the tolerance value
+    constexpr half_t half_t_tol  = 1e-3;
+    constexpr half_t half_t_zero = 0.0;
+
+    // convert 0 half_t to bf8 and back, check if holds
+    ASSERT_NEAR(
+        half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(half_t_zero)), half_t_zero);
+
+    // convert minimal half_t to bf8 and back, check if holds
+    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
+                type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
+                half_t_tol);
+
+    const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
+
+    // convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
+    ASSERT_NEAR(max_bf8_t_half_t,
+                type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_half_t)),
+                half_t_zero);
+
+    // convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
+    ASSERT_NEAR(max_bf8_t_half_t,
+                type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
+                half_t_zero);
+
+    // convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
+    ASSERT_EQ(
+        ck::NumericLimits<bf8_ocp_t>::Max(),
+        f8_convert_rne<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
+
+    // positive normal bf8 value to bf8 and back, check if holds
+    constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
+    ASSERT_NEAR(
+        pos_norm_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
+
+    // negative smallest normal bf8 value to bf8 and back, check if holds
+    constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
+    ASSERT_NEAR(
+        neg_min_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
+
+    // positive subnorm bf8 value to bf8 and back, check if holds
+    constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
+    ASSERT_NEAR(pos_subnorm_bf8,
+                type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)),
+                half_t_zero);
+
+    // min subnorm bf8 value to bf8 and back, check if holds
+    constexpr half_t min_subnorm_bf8{-0.0000152587890625f}; //-2^-16
+    ASSERT_NEAR(min_subnorm_bf8,
+                type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)),
+                half_t_zero);
+
+    // smaller than min subnorm bf8 value to bf8 must be zero
+    constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
+    ASSERT_EQ(half_t_zero,
+              type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
+
+    // convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
+    const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
+    ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
+}
+
+TEST(BF8OCP, ConvertFP16Stochastic)
+{
+    // fix the tolerance value
+    constexpr half_t half_t_tol    = 1e-3;
+    constexpr half_t half_t_zero   = 0.0;
+    constexpr auto min_subnorm_bf8 = 0.0000152587890625f; // 2^-16
+
+    // convert 0 half_t to bf8 and back, check if holds
+    ASSERT_NEAR(
+        half_t_zero, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t_zero)), half_t_zero);
+
+    // convert minimal half_t (6.103515625e-05) to bf8 and back
+    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
+                type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
+                half_t_zero);
+
+    const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
+
+    // convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
+    ASSERT_NEAR(max_bf8_t_half_t,
+                type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_half_t)),
+                half_t_zero);
+
+    // convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
+    ASSERT_NEAR(max_bf8_t_half_t,
+                type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
+                half_t_zero);
+
+    // convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
+    ASSERT_EQ(
+        ck::NumericLimits<bf8_ocp_t>::Max(),
+        f8_convert_sr<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
+
+    // positive normal bf8 value to bf8 and back, check if holds
+    constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
+    ASSERT_NEAR(
+        pos_norm_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
+
+    // negative smallest normal bf8 value to bf8 and back, check if holds
+    constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
+    ASSERT_NEAR(
+        neg_min_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
+
+    // positive subnorm bf8 value to bf8 and back, check if holds
+    constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
+    ASSERT_NEAR(pos_subnorm_bf8,
+                type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)),
+                half_t_zero);
+
+    // min subnorm bf8 value to bf8 and back, check if holds
+    ASSERT_NEAR(half_t{-min_subnorm_bf8},
+                type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t{-min_subnorm_bf8})),
+                half_t_zero);
+
+    // smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
+    constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
+    ASSERT_NEAR(half_t_zero,
+                type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
+                half_t{min_subnorm_bf8});
+
+    // convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
+    const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
+    ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
+}
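
Every round-trip case in test_bf8_ocp.cpp repeats one convert-and-compare shape. A hypothetical helper — the name and signature are illustrative, not part of CK — makes that shape explicit:

    // Convert to F8 with round-to-nearest-even and back to Wide, then
    // require the round trip to stay within tol of the original value.
    template <typename F8, typename Wide>
    void expect_roundtrip_near(Wide value, Wide tol)
    {
        ASSERT_NEAR(value, ck::type_convert<Wide>(ck::f8_convert_rne<F8>(value)), tol);
    }

    // usage, mirroring one assertion above:
    // expect_roundtrip_near<ck::bf8_ocp_t>(-0.00006103515625f, 0.0f);
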
diff --git a/test/data_type/test_custom_type.cpp b/test/data_type/test_custom_type.cpp
index 1016812544..a8fa9ba4a0 100644
--- a/test/data_type/test_custom_type.cpp
+++ b/test/data_type/test_custom_type.cpp
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
                       test_vec.at(num_elem * i + 1));
     });
 }
+
+#if CK_USE_OCP_FP8
+
+TEST(FP8OCP, TestSize)
+{
+    static_assert(std::is_same_v<f8_t, f8_ocp_t>, "OCP FP8 is not enabled");
+    ASSERT_EQ(sizeof(f8_t), sizeof(ck::fp8_storage_t));
+    ASSERT_EQ(sizeof(vector_type<f8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
+    ASSERT_EQ(sizeof(vector_type<f8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
+    ASSERT_EQ(sizeof(vector_type<f8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
+    ASSERT_EQ(sizeof(vector_type<f8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
+    ASSERT_EQ(sizeof(vector_type<f8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
+    ASSERT_EQ(sizeof(vector_type<f8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
+}
+
+TEST(FP8OCP, TestAsType)
+{
+    static_assert(std::is_same_v<f8_t, f8_ocp_t>, "OCP FP8 is not enabled");
+
+    // test size
+    std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
+    constexpr int size            = test_vec.size();
+
+    // reference vector
+    vector_type<f8_t, size> right_vec;
+
+    // check default CTOR
+    ck::static_for<0, size, 1>{}(
+        [&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
+
+    // assign test values to the vector
+    ck::static_for<0, size, 1>{}([&](auto i) {
+        right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
+    });
+
+    // copy the vector
+    vector_type<f8_t, size> left_vec{right_vec};
+
+    // check if values were copied correctly
+    ck::static_for<0, size, 1>{}([&](auto i) {
+        ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
+                  ck::type_convert<f8_t>(test_vec.at(i)));
+    });
+
+    ck::non_native_vector_base<f8_t, 2> nnvb_f8x2(ck::type_convert<f8_t>(-10.0f));
+    ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<0>{}), ck::type_convert<f8_t>(-10.0f));
+    ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<1>{}), ck::type_convert<f8_t>(-10.0f));
+}
+
+TEST(FP8OCP, TestAsTypeReshape)
+{
+    static_assert(std::is_same_v<f8_t, f8_ocp_t>, "OCP FP8 is not enabled");
+
+    // test size
+    std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1.0 / 256, 1, 1.5, 16};
+    constexpr int size            = test_vec.size();
+
+    // reference vector
+    vector_type<f8_t, size> right_vec;
+
+    // check default CTOR
+    ck::static_for<0, size, 1>{}(
+        [&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
+
+    // assign test values to the vector
+    ck::static_for<0, size, 1>{}([&](auto i) {
+        right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
+    });
+
+    // copy the first half of a vector
+    vector_type<f8_t, size / 2> left_vec{
+        right_vec.template AsType<vector_type<f8_t, size / 2>::type>()(Number<0>{})};
+
+    // check if values were copied correctly
+    ck::static_for<0, size / 2, 1>{}([&](auto i) {
+        ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
+                  ck::type_convert<f8_t>(test_vec.at(i)));
+    });
+}
+
+TEST(BF8OCP, TestSize)
+{
+    static_assert(std::is_same_v<bf8_t, bf8_ocp_t>, "OCP BF8 is not enabled");
+    ASSERT_EQ(sizeof(bf8_t), sizeof(ck::fp8_storage_t));
+    ASSERT_EQ(sizeof(vector_type<bf8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
+    ASSERT_EQ(sizeof(vector_type<bf8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
+    ASSERT_EQ(sizeof(vector_type<bf8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
+    ASSERT_EQ(sizeof(vector_type<bf8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
+    ASSERT_EQ(sizeof(vector_type<bf8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
+    ASSERT_EQ(sizeof(vector_type<bf8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
+}
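
The TestAsType and TestAsTypeReshape cases around here exercise vector_type's AsType accessor, which reinterprets one contiguous register image as packed elements (or packed sub-vectors) without copying. A hedged usage sketch, assuming the OCP build where ck::f8_t aliases ck::f8_ocp_t:

    // Four packed fp8 lanes in one register image.
    ck::vector_type<ck::f8_t, 4> v;
    // Element view: write lane 0 through the AsType accessor.
    v.template AsType<ck::f8_t>()(ck::Number<0>{}) = ck::type_convert<ck::f8_t>(1.5f);
    // Sub-vector view: the same four bytes seen as two 2-wide packs.
    auto lo_pack = v.template AsType<ck::vector_type<ck::f8_t, 2>::type>()(ck::Number<0>{});
    (void)lo_pack;
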
+
+TEST(BF8OCP, TestAsType)
+{
+    static_assert(std::is_same_v<bf8_t, bf8_ocp_t>, "OCP BF8 is not enabled");
+
+    // test size
+    std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
+    constexpr int size            = test_vec.size();
+
+    // reference vector
+    vector_type<bf8_t, size> right_vec;
+
+    // check default CTOR
+    ck::static_for<0, size, 1>{}(
+        [&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
+
+    // assign test values to the vector
+    ck::static_for<0, size, 1>{}([&](auto i) {
+        right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
+    });
+
+    // copy the vector
+    vector_type<bf8_t, size> left_vec{right_vec};
+
+    // check if values were copied correctly
+    ck::static_for<0, size, 1>{}([&](auto i) {
+        ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
+                  ck::type_convert<bf8_t>(test_vec.at(i)));
+    });
+
+    ck::non_native_vector_base<bf8_t, 2> nnvb_bf8x2(ck::type_convert<bf8_t>(-10.0f));
+    ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<0>{}), ck::type_convert<bf8_t>(-10.0f));
+    ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<1>{}), ck::type_convert<bf8_t>(-10.0f));
+}
+
+TEST(BF8OCP, TestAsTypeReshape)
+{
+    static_assert(std::is_same_v<bf8_t, bf8_ocp_t>, "OCP BF8 is not enabled");
+
+    // test size
+    std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1.0 / 256, 1, 1.5, 16};
+    constexpr int size            = test_vec.size();
+
+    // reference vector
+    vector_type<bf8_t, size> right_vec;
+
+    // check default CTOR
+    ck::static_for<0, size, 1>{}(
+        [&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
+
+    // assign test values to the vector
+    ck::static_for<0, size, 1>{}([&](auto i) {
+        right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
+    });
+
+    // copy the first half of a vector
+    vector_type<bf8_t, size / 2> left_vec{
+        right_vec.template AsType<vector_type<bf8_t, size / 2>::type>()(Number<0>{})};
+
+    // check if values were copied correctly
+    ck::static_for<0, size / 2, 1>{}([&](auto i) {
+        ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
+                  ck::type_convert<bf8_t>(test_vec.at(i)));
+    });
+}
+
+#endif
diff --git a/test/data_type/test_fp8.cpp b/test/data_type/test_fp8_fnuz.cpp
similarity index 52%
rename from test/data_type/test_fp8.cpp
rename to test/data_type/test_fp8_fnuz.cpp
index 25d9d9d2fb..c2ec6dad94 100644
--- a/test/data_type/test_fp8.cpp
+++ b/test/data_type/test_fp8_fnuz.cpp
@@ -7,154 +7,171 @@
 using ck::f8_convert_rne;
 using ck::f8_convert_sr;
-using ck::f8_t;
+using ck::f8_fnuz_t;
 using ck::half_t;
 using ck::type_convert;
 
-TEST(FP8, NumericLimits)
+TEST(FP8FNUZ, NumericLimits)
 { // constants given for negative zero nan mode
-    EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08));
-    EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F));
-    EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF));
-    EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80));
+    EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Min(), type_convert<f8_fnuz_t>(0x08));
+    EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Max(), type_convert<f8_fnuz_t>(0x7F));
+    EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Lowest(), type_convert<f8_fnuz_t>(0xFF));
+    EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::QuietNaN(), type_convert<f8_fnuz_t>(0x80));
 }
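
For reference while reading the constants above: f8_fnuz_t is an e4m3 format with exponent bias 8, no infinities, and a single NaN at 0x80 — the bit pattern that would otherwise be negative zero, hence "fnuz" (finite, NaN, unsigned zero). A hypothetical standalone decoder, not a CK API:

    #include <cmath>
    #include <cstdint>

    // Decode an e4m3 FNUZ byte; bias 8, sole NaN at 0x80, no infinities.
    float decode_e4m3_fnuz(std::uint8_t bits)
    {
        if(bits == 0x80)
            return NAN; // the "negative zero" pattern encodes NaN
        const int s = bits >> 7;
        const int e = (bits >> 3) & 0xF; // 4 exponent bits
        const int m = bits & 0x7;        // 3 mantissa bits
        const float mag = (e == 0) ? std::ldexp(m / 8.0f, -7)            // subnormal
                                   : std::ldexp(1.0f + m / 8.0f, e - 8); // normal
        return s ? -mag : mag;
    }

    // decode_e4m3_fnuz(0x08) == 2^-7 (Min), decode_e4m3_fnuz(0x7F) == 240 (Max),
    // decode_e4m3_fnuz(0xFF) == -240 (Lowest).
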
-TEST(FP8, ConvertFP32Nearest)
+TEST(FP8FNUZ, ConvertFP32Nearest)
 {
     // fix the tolerance value
     float abs_tol = 1e-6;
     // convert 0 float to fp8 and back, check if holds
-    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol);
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_fnuz_t>(0.0f)), abs_tol);
 // don't run the next test on gfx11 devices
 #ifndef CK_SKIP_FLAKY_F8_TEST
     // convert minimal float to fp8 and back, check if holds
     ASSERT_NEAR(std::numeric_limits<float>::min(),
-                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())),
+                type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::min())),
                 abs_tol);
 #endif
-    // convert maximal f8_t to float and check if equal to 240.0
-    ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol);
-    // convert maximal float to fp8 and back, check if clipped to 240.0
-    ASSERT_NEAR(240.0f,
-                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())),
+
+    const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
+    // convert maximal f8_fnuz_t to float and check if equal to fp8 max
+    ASSERT_NEAR(
+        max_f8_t_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(max_f8_t_float)), abs_tol);
+
+    // XXX: FNUZ f8_convert_rne behavior is inconsistent.
+    // Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
+    // fp8 qNAN (no saturation).
+
+    // convert maximal float to fp8 and back, check if clipped to fp8 max
+    ASSERT_NEAR(max_f8_t_float,
+                type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::max())),
                 abs_tol);
-    // convert inf float to f8_t and check if it is qNan
-    ASSERT_NEAR(type_convert<f8_t>(0x80),
-                f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()),
+    // convert inf float to f8_fnuz_t and check if it is qNan
+    ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
+                f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
                 abs_tol);
     // positive norm float value to fp8 and back, check if holds
     float pos_float = 0.017578125f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
     // negative norm float value to fp8 and back, check if holds
     float neg_float = -0.015625f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
     // positive subnorm float value to fp8 and back, check if holds
     pos_float = 0.00390625f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
     // negative subnorm float value to fp8 and back, check if holds
     neg_float = -0.001953125f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
 }
 
-TEST(FP8, ConvertFP32Stochastic)
+TEST(FP8FNUZ, ConvertFP32Stochastic)
 {
     // fix the tolerance value
     float abs_tol = 1e-6;
     // convert 0 float to fp8 and back, check if holds
-    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), abs_tol);
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_fnuz_t>(0.0f)), abs_tol);
     // convert minimal float to fp8 and back, check if holds
     ASSERT_NEAR(std::numeric_limits<float>::min(),
-                type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::min())),
+                type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::min())),
                 abs_tol);
-    // convert maximal f8_t to float and check if equal to 240.0
-    ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_sr<f8_t>(240.0f)), abs_tol);
-    // convert maximal float to fp8 and back, check if clipped to 240.0
-    ASSERT_NEAR(240.0f,
-                type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
+
+    const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
+    // convert maximal f8_fnuz_t to float and check if equal to fp8 max
+    ASSERT_NEAR(
+        max_f8_t_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(max_f8_t_float)), abs_tol);
+    // convert maximal float to fp8 and back, check if clipped to fp8 max
+    ASSERT_NEAR(max_f8_t_float,
+                type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::max())),
                 abs_tol);
-    // convert inf float to f8_t and check if it is qNan
-    ASSERT_NEAR(type_convert<f8_t>(0x80),
-                f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()),
+    // convert inf float to f8_fnuz_t and check if it is qNan
+    ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
+                f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
                 abs_tol);
     // positive norm float value to fp8 and back, check if holds
     float pos_float = 0.017578125f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
     // negative norm float value to fp8 and back, check if holds
     float neg_float = -0.015625f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
     // positive subnorm float value to fp8 and back, check if holds
     pos_float = 0.00390625f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
     // negative subnorm float value to fp8 and back, check if holds
     neg_float = -0.001953125f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
 }
 
-TEST(FP8, ConvertFP16Nearest)
+TEST(FP8FNUZ, ConvertFP16Nearest)
 {
     // fix the tolerance value
     float abs_tol = 1e-3;
     // convert 0 fp16 to fp8 and back, check if holds
-    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol);
+    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(half_t{0.0})), abs_tol);
     // convert minimal fp16 to fp8 and back, check if holds
     ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
-                type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
                 abs_tol);
-    // convert maximal f8_t to fp16 and check if equal to 240.0
-    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol);
-    // convert maximal fp16 to fp8 and back, check if clipped to 240.0
-    ASSERT_NEAR(half_t{240.0},
-                type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())),
+
+    const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
+    // convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
+    ASSERT_NEAR(
+        max_f8_t_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(max_f8_t_half)), abs_tol);
+    // convert maximal fp16 to fp8 and back, check if clipped to fp8 max
+    ASSERT_NEAR(max_f8_t_half,
+                type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
-    // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
-    ASSERT_NEAR(type_convert<f8_t>(0x80),
-                f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+    // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
+    ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
+                f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
                 abs_tol);
     // positive norm fp16 value to fp8 and back, check if holds
     half_t pos_half = half_t{0.017578125};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
     // negative norm fp16 value to fp8 and back, check if holds
     half_t neg_half = half_t{-0.015625};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
     // positive subnorm fp16 value to fp8 and back, check if holds
     pos_half = half_t{0.00390625};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
     // negative subnorm fp16 value to fp8 and back, check if holds
     neg_half = half_t{-0.001953125};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
 }
 
-TEST(FP8, ConvertFP16Stochastic)
+TEST(FP8FNUZ, ConvertFP16Stochastic)
 {
     // fix the tolerance value
     float abs_tol = 1e-3;
     // convert 0 fp16 to fp8 and back, check if holds
-    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), abs_tol);
+    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(half_t{0.0})), abs_tol);
     // convert minimal fp16 to fp8 and back, check if holds
     ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
-                type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
                 abs_tol);
-    // convert maximal f8_t to fp16 and check if equal to 240.0
-    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{240.0})), abs_tol);
-    // convert maximal fp16 to fp8 and back, check if clipped to 240.0
-    ASSERT_NEAR(half_t{240.0},
-                type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
+
+    const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
+    // convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
+    ASSERT_NEAR(
+        max_f8_t_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(max_f8_t_half)), abs_tol);
+    // convert maximal fp16 to fp8 and back, check if clipped to fp8 max
+    ASSERT_NEAR(max_f8_t_half,
+                type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
-    // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
-    ASSERT_NEAR(type_convert<f8_t>(0x80),
-                f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+    // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
+    ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
+                f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
                 abs_tol);
     // positive norm fp16 value to fp8 and back, check if holds
     half_t pos_half = half_t{0.017578125};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
     // negative norm fp16 value to fp8 and back, check if holds
     half_t neg_half = half_t{-0.015625};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
     // positive subnorm fp16 value to fp8 and back, check if holds
     pos_half = half_t{0.00390625};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
     // negative subnorm fp16 value to fp8 and back, check if holds
     neg_half = half_t{-0.001953125};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
 }
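
Taken together, the FNUZ file above and the OCP file below pin down how the two fp8 flavors diverge on overflow: FNUZ clips finite overflow to ±240 yet maps infinity to its qNaN, while OCP saturates both finite overflow and infinity to Max. The sketch below is a hypothetical test that merely restates, side by side, two expectations the real files assert separately; it assumes both types and the same headers are available in one translation unit.

    TEST(FP8Formats, OverflowSketch)
    {
        // FNUZ: +inf does not saturate; it becomes the qNaN encoding 0x80.
        ASSERT_NEAR(ck::NumericLimits<ck::f8_fnuz_t>::QuietNaN(),
                    ck::f8_convert_rne<ck::f8_fnuz_t>(std::numeric_limits<float>::infinity()),
                    1e-6);
        // OCP: +inf saturates to the finite maximum (448).
        ASSERT_EQ(ck::NumericLimits<ck::f8_ocp_t>::Max(),
                  ck::f8_convert_rne<ck::f8_ocp_t>(std::numeric_limits<float>::infinity()));
    }
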
diff --git a/test/data_type/test_fp8_ocp.cpp b/test/data_type/test_fp8_ocp.cpp
new file mode 100644
index 0000000000..a8077f1bdf
--- /dev/null
+++ b/test/data_type/test_fp8_ocp.cpp
@@ -0,0 +1,250 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gtest/gtest.h"
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/type_convert.hpp"
+
+using ck::f8_convert_rne;
+using ck::f8_convert_sr;
+using ck::f8_ocp_t;
+using ck::half_t;
+using ck::type_convert;
+
+TEST(FP8OCP, NumericLimits)
+{
+    // constants given for OCP FP8
+    EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Min(),
+              type_convert<f8_ocp_t>(0x08)); // 0b00001000 = 2^-6
+    EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
+              type_convert<f8_ocp_t>(0x7E)); // 0b01111110 = 448
+    EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Lowest(),
+              type_convert<f8_ocp_t>(0xFE)); // 0b11111110 = -448
+    EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::QuietNaN().data,
+              type_convert<f8_ocp_t>(0x7F).data); // 0b01111111
+    EXPECT_FALSE(ck::NumericLimits<f8_ocp_t>::QuietNaN() ==
+                 ck::NumericLimits<f8_ocp_t>::QuietNaN());
+}
+
+TEST(FP8OCP, ConvertFP32Nearest)
+{
+    // fix the tolerance value
+    float abs_tol = 1e-6;
+    // convert 0 float to fp8 and back, check if holds
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(0.0f)), 0.0f);
+
+    // convert minimal float to fp8 and back, check if holds
+    ASSERT_NEAR(std::numeric_limits<float>::min(),
+                type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::min())),
+                abs_tol);
+
+    const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
+
+    // convert maximal f8_ocp_t to float and check if equal to fp8 max
+    ASSERT_NEAR(
+        max_f8_t_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(max_f8_t_float)), 0.0f);
+
+    // convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
+    ASSERT_NEAR(max_f8_t_float,
+                type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::max())),
+                0.0f);
+
+    // convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
+    ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
+              f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity()));
+
+    // positive norm float value to fp8 and back, check if holds
+    float pos_float = 0.017578125f;
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
+
+    // smallest normal fp8 value to fp8 and back, check if holds
+    float neg_float = -0.015625f; //-2^-6
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
+
+    // positive subnorm float value to fp8 and back, check if holds
+    pos_float = 0.00390625f;
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
+
+    // min subnorm fp8 value to fp8 and back, check if holds
+    neg_float = -0.001953125f; //-2^-9
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
+
+    // smaller than min subnorm fp8 value to fp8 must be zero
+    auto less_than_min_subnorm = 0.0009765625f; // 2^-10
+    ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
+
+    // convert quiet NaN to f8_ocp_t and check if it is quiet NaN
+    auto f8_nan = f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
+    ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
+}
+
+TEST(FP8OCP, ConvertFP32Stochastic)
+{
+    // fix the tolerance value
+    float abs_tol = 1e-6;
+    // convert 0 float to fp8 and back, check if holds
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(0.0f)), 0.0f);
+
+    // convert minimal float to fp8 and back, check if holds
+    ASSERT_NEAR(std::numeric_limits<float>::min(),
+                type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::min())),
+                abs_tol);
+
+    const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
+
+    // convert maximal f8_ocp_t to float and check if equal to fp8 max
+    ASSERT_NEAR(max_f8_t_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(max_f8_t_float)), 0.0f);
+
+    // convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
+    ASSERT_NEAR(max_f8_t_float,
+                type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::max())),
+                0.0f);
+
+    // convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
+    ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
+              f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::infinity()));
+
+    // positive norm float value to fp8 and back, check if holds
+    float pos_float = 0.017578125f;
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
+
+    // smallest normal fp8 value to fp8 and back, check if holds
+    float neg_float = -0.015625f; //-2^-6
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(neg_float)), 0.0f);
+
+    // positive subnorm float value to fp8 and back, check if holds
+    pos_float = 0.00390625f;
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
+
+    // min subnorm fp8 value to fp8 and back, check if holds
+    constexpr auto min_subnorm_fp8 = -0.001953125f; //-2^-9
+    ASSERT_NEAR(
+        min_subnorm_fp8, type_convert<float>(f8_convert_sr<f8_ocp_t>(min_subnorm_fp8)), 0.0f);
+
+    // smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
+    auto less_than_min_subnorm = 0.0009765625f; // 2^-10
+    ASSERT_NEAR(
+        0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm)), 0.001953125f);
+
+    // convert quiet NaN to f8_ocp_t and check if it is quiet NaN
+    auto f8_nan = f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
+    ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
+}
+
+TEST(FP8OCP, ConvertFP16Nearest)
+{
+    // fix the tolerance value
+    constexpr half_t half_t_tol  = 1e-3;
+    constexpr half_t half_t_zero = 0.0;
+    // convert 0 half_t to fp8 and back, check if holds
+    ASSERT_NEAR(
+        half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(half_t_zero)), half_t_zero);
+
+    // convert minimal half_t to fp8 and back, check if holds
+    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
+                type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
+                half_t_tol);
+
+    const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
+
+    // convert maximal f8_ocp_t to half_t and check if equal to fp8 max
+    ASSERT_NEAR(max_f8_t_half_t,
+                type_convert<half_t>(f8_convert_rne<f8_ocp_t>(max_f8_t_half_t)),
+                half_t_zero);
+
+    // convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
+    ASSERT_NEAR(max_f8_t_half_t,
+                type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
+                half_t_zero);
+
+    // convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
+    ASSERT_EQ(
+        ck::NumericLimits<f8_ocp_t>::Max(),
+        f8_convert_rne<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
+
+    // positive norm half_t value to fp8 and back, check if holds
+    half_t pos_half_t{0.017578125f};
+    ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
+
+    // smallest normal fp8 value to fp8 and back, check if holds
+    half_t neg_half_t{-0.015625f}; //-2^-6
+    ASSERT_NEAR(
+        neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
+
+    // positive subnorm half_t value to fp8 and back, check if holds
+    pos_half_t = half_t{0.00390625f};
+    ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
+
+    // min subnorm fp8 value to fp8 and back, check if holds
+    neg_half_t = half_t{-0.001953125f}; //-2^-9
+    ASSERT_NEAR(
+        neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
+
+    // smaller than min subnorm fp8 value to fp8 must be zero
+    auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
+    ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
+
+    // convert quiet NaN to f8_ocp_t and check if it is quiet NaN
+    auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
+    ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
+}
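
A worked tie-break for the 2^-10 expectation above: the smallest positive OCP e4m3 magnitude is the subnormal 2^-9, so 2^-10 lies exactly halfway between the neighbors 0 and 2^-9, and round-to-nearest-even resolves the tie toward the even mantissa, which is 0 (illustration only):

    // 2^-10 is equidistant from 0 (mantissa bits 000, even) and 2^-9 (001, odd),
    // so RNE returns 0; all three constants below are exact in binary floating point.
    static_assert(0.0009765625f - 0.0f == 0.001953125f - 0.0009765625f,
                  "2^-10 is equidistant from 0 and 2^-9");
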
+
+TEST(FP8OCP, ConvertFP16Stochastic)
+{
+    // fix the tolerance value
+    constexpr half_t half_t_tol    = 1e-3;
+    constexpr half_t half_t_zero   = 0.0;
+    constexpr auto min_subnorm_fp8 = 0.001953125f; // 2^-9
+
+    // convert 0 half_t to fp8 and back, check if holds
+    ASSERT_NEAR(
+        half_t_zero, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(half_t_zero)), half_t_zero);
+
+    // convert minimal half_t (6.103515625e-05) to fp8 and back
+    // alternates between 0 and 2^-9 (0.001953125)
+    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
+                type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(min_subnorm_fp8));
+
+    const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
+
+    // convert maximal f8_ocp_t to half_t and check if equal to fp8 max
+    ASSERT_NEAR(max_f8_t_half_t,
+                type_convert<half_t>(f8_convert_sr<f8_ocp_t>(max_f8_t_half_t)),
+                half_t_zero);
+
+    // convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
+    ASSERT_NEAR(max_f8_t_half_t,
+                type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
+                half_t_zero);
+
+    // convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
+    ASSERT_EQ(
+        ck::NumericLimits<f8_ocp_t>::Max(),
+        f8_convert_sr<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
+
+    // positive norm half_t value to fp8 and back, check if holds
+    half_t pos_half_t{0.017578125f};
+    ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
+
+    // smallest normal fp8 value to fp8 and back, check if holds
+    half_t neg_half_t{-0.015625f}; //-2^-6
+    ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
+
+    // positive subnorm half_t value to fp8 and back, check if holds
+    pos_half_t = half_t{0.00390625f};
+    ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
+
+    // min subnorm fp8 value to fp8 and back, check if holds
+    neg_half_t = half_t{-min_subnorm_fp8}; //-2^-9
+    ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
+
+    // smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
+    auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
+    ASSERT_NEAR(
+        type_convert<float>(half_t_zero),
+        type_convert<float>(type_convert<half_t>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm))),
+        min_subnorm_fp8);
+
+    // convert quiet NaN to f8_ocp_t and check if it is quiet NaN
+    auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
+    ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
+}
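
The stochastic cases above only bound single conversions; the property that motivates them is that the rounding error vanishes in expectation. If f8_convert_sr draws fresh randomness on every call — an implementation detail these tests do not guarantee — a mean test could probe that directly. A hedged, hypothetical sketch (disabled by name, not part of CK):

    TEST(FP8OCP, DISABLED_StochasticMeanSketch)
    {
        const float x = 0.0009765625f; // 2^-10, below the smallest subnormal 2^-9
        double sum    = 0.0;
        const int n   = 10000;
        for(int i = 0; i < n; ++i)
            sum += ck::type_convert<float>(ck::f8_convert_sr<ck::f8_ocp_t>(x));
        // the sample mean should approach 2^-10 if rounding is unbiased
        EXPECT_NEAR(x, sum / n, 0.25 * 0.001953125);
    }
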
diff --git a/test/pool/test_avg_pool2d_fwd.cpp b/test/pool/test_avg_pool2d_fwd.cpp
index 8dbb37b84f..b5e733419a 100644
--- a/test/pool/test_avg_pool2d_fwd.cpp
+++ b/test/pool/test_avg_pool2d_fwd.cpp
@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
 TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
 TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
 
-TYPED_TEST(AvgPool2D_F32, AvgPool2D_I8_Test) { this->Run(); }
+TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
 TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
 TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
 TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
diff --git a/test/pool/test_max_pool2d_fwd.cpp b/test/pool/test_max_pool2d_fwd.cpp
index 80ca47407b..2179242754 100644
--- a/test/pool/test_max_pool2d_fwd.cpp
+++ b/test/pool/test_max_pool2d_fwd.cpp
@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
 TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
 TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
 
-TYPED_TEST(MaxPool2D_F32, MaxPool2D_I8_Test) { this->Run(); }
+TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
 TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
 TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
 TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }