mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
OCP FP8 support for gfx12. (#1710)
* (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds * (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds * (4/5) grouped conv pass * (5/5) attention pass, todo: debug lds perf bug * AIT Attention API refactor (#8) * sanity pass * sanity pass 2 * confirm significant performance regression. * turn on all instances * turn off instance format * Fix bug & tunning & format * DML meta, self_attn+cross_attn * sanity pass * remove useless flag * update tile and problem size used in AIT attention * bug fix in grouped conv supporting check * deprecate inline asm wmma * Bug fix: double lds skip * clang-format * Fix errors in 1. example, fmha 2. gridwise pipeline 3. deviceop, fmha, change some containers from vector to array * part2 of previous commit * clang format * API fix of gridwisegemmpipeline * separate array base and vector base attention tensor transformation * fix gemm * clang format * add gemm fp16 instances * Temp save * fpAintB kernel compile pass * Sanity pass. * Temp save * debug code enabled * Fp16AInt8B_GEMM sanity * MQA implementation * GQA-4 example * tempsave * Compile pass * New implementation of fp16Aint8B Gemm, Acheieve similar math throughput with native fp16 Gemm * Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * initial enablement of gfx950 * fix clang format * disable examples 31 and 41 int8 on gfx950 * initial navi4x enablement * remove extra endif * enabled dl_gemm * update s_barrier and s_waitcnt for gfx12 * fix the gfx12 assembly syntax * fixed block_sync_lds * add support for more dl kernels on navi4 * add wmma * format * Todo: fix gemm_bilinear_wmma instances compilation bug * Solve a bug when K1=16 * remove unnecessary changes * Remove tensor layout limitation to LDS usage in tesnor contraction * fixed block_sync_lds * merge navi3_ref * update self-attention and cross-attention * fix a typo of name * fixed layout * debugging * Add arch limiter for fp8 gemm * fixed wmma * enable fp8 gemm_xdl for all gfx9 targets * temporarily disable gemm_xdl_fp16_fp8 on MI100/200 * fix the cmake logic for gemm_xdl_fp16_fp8 * fixed c_output * re-enable the gemm_xdl_fp16_fp8 on MI100/200 * fixed gfx12 * fixed * fixed * seperate gfx12 blockwise_gemm * fixed * enable fwd conv on navi4x * enable gridwise * enabled gemm * fixed merge * remove empty example fold * fixed conflicts * some small changes * Update cmake-ck-dev.sh * Update cmake-ck-dev.sh * enabled other types * fixed register loads * test fa * enable gfx12 * clean up * enable some instances on gfx12 * add gfx1201 macro in amd_wmma header * fix clang format * enable batched_gemm_softmax_gemm_perm_wmma for gfx12 * disable instances with blocksize=256 in attention examples * debuggging * debug * fixed lds_enabled * debugging * Fix and add limit to skiplds feature * Enable skipLds feature and fix compilation bugs * add ck_tile definitions for gfx12 * fix clang format and test/wmma_op * updage instances cmake for gfx12 * disable the test_wmma_op on gfx12 * fix the builds for gfx950 * add gfx12 and gfx950 to default target list * clean-up cmake file * Initial introduction of OFP8 data types. * Renamed FP8 and BF8 tests into FP8_FNUZ and BF8_FNUZ. * Implementation of ConvertFP32Nearest in test_fp8_ocp. * Remove dependence on possibly undeclared alias. * Implement FP8OCP test for stochastic rounding mode. * Implement FP8OCP tests for half_t type conversions. * enable bf16 atomic add on gfx950 * Implement ConvertFP32Nearest test. * Implement ConvertFP32Stochastic test. * Implement ConvertFP16Nearest and ConvertFP16Stochastic tests. * Refactoring. Move FP8 definitions into a separate header file. * Enable easy switching between architectures. * Fix compilation error for gfx942 architecture. * only builf gfx950 branch for gfx950 target by default * Enable OCP build of example_gemm_xdl_fp8. * Fix formatting. * fix the build logic for gfx950 * Improve GEMM example verbosity. * Add constexpr where applicable. * fix the logic of enabling XDL and WMMA instances * Improve GEMM example verbosity. * Enable build of example_gemm_xdl_fp8_bf8 test. * Fix tests for gfx1101 architecture. * Build DPP examples only on gfx103 and gfx11 architectures. * Optionaly run either CPU or GPU verifications with GEMM examples. * Extend GeneratorTensor_Sequential to produce values of prescribed data types. * Add missing constructor. * Improve infrastructure for OFP8 data type support. * BUGFIX. Should not use FP8 as Compute/Accum data type. * Add custom target for grouped_convnd_bwd_weight tests. * Can build `tests` target on gfx950. * Bugfixes on gfx1101 architecture. * Fix dependencies. * Provide single point of truth for FP8 INF and NAN checks * Prevent instantiation of operators that are not supported by FP8 data types * Add FP8 type selection into client_axample CMakeLists.txt * Prevent sccache server from shutting down during build * Fix test success reporting logic * Change default verification method to CPU. GPU verification takes too much time to complete on the emulator. * Make sure all tests and examples are built for gfx950 * Facilitate testing of FP8 data types on the emulator * Introduce two new tensor generators * Enable instances built for gfx94 to be built on gfx950 * Verify 35_splitk_gemm on floating point numbers. splitk gemm appears to be losing precision VS reference implementation when FP numbers are involved. * Verify 04_gemm_add_add_fastgelu on floating point numbers * Verify 20_grouped_conv_bwd_weight on floating point numbers * Verify 38_grouped_conv_bwd_data_multiple_d on floating point numbers * Verify more tests on floating point data * Fix data types and improve testing verbocity. * Upgrade to NPI 573 build docker. * Skip on gemm_universal tests. The tests take too long to complete on the emulator. Need to see if it is possible to reduce the scope of the testing to just FP8 data types. * Fix gfx1101 build * Document test availability * Re-enable fp8 gemms for gfx94/95 * Cherry-pick GEMM Universal tests for FP8 data types * Cleanup * CK_USE_GFX94 has already been set on this branch * Address formatting issues and leftovers * Make fail/pass logic consistent within 01_gemm folder Removed multiple negations in fail/pass logic to propagate `true` as the success indicator. * Fix GPU verification reporting logic. * Update year in copyright notice. * Cleanup * Use `enum class` instead of `enum` * Remove set_property for FP8 tests * Narrowing the scope of PR to OCP FP8 enablement only * Add tests for OCP FP8 vector_type storage * Enable gemm kernel on all gfx9 architectures (#227) * clean-up * Implement `non_native_vector_base` with `ext_vector_type` array. (#232) * Enable support of 1, 2, 4, and 8-byte custom types in CK. * Fix pool tests for OCP FP8 data type * fix jenkins file * restore cron trigger --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jing Zhang <jizhan@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
This commit is contained in:
@@ -185,13 +185,22 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
message("Enabling FP8 gemms in ckProfiler")
|
||||
message("Enabling FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message("Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
add_definitions(-DCK_USE_OCP_FP8)
|
||||
set(CK_USE_OCP_FP8 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
add_definitions(-DCK_USE_FNUZ_FP8)
|
||||
set(CK_USE_FNUZ_FP8 "ON")
|
||||
endif()
|
||||
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
|
||||
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
|
||||
@@ -56,6 +56,14 @@ if (GPU_TARGETS)
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx12")
|
||||
add_definitions(-DCK_USE_OCP_FP8)
|
||||
set(CK_USE_OCP_FP8 "ON")
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
|
||||
add_definitions(-DCK_USE_FNUZ_FP8)
|
||||
set(CK_USE_FNUZ_FP8 "ON")
|
||||
endif()
|
||||
else()
|
||||
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
|
||||
@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
|
||||
int do_verification = 3;
|
||||
int do_verification = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
|
||||
break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
|
||||
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDMatrices; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDMatrices; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<D0DataType, 1>{});
|
||||
}
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
|
||||
|
||||
@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -175,8 +175,8 @@ int main(int argc, char* argv[])
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
|
||||
c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -118,7 +118,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -153,7 +153,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -178,7 +178,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -152,7 +152,7 @@ int run(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
@@ -173,7 +173,7 @@ int run(int argc, char* argv[])
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -377,7 +377,7 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
default:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
|
||||
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -41,7 +41,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParams \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
|
||||
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
|
||||
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
|
||||
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
|
||||
a1_device_buf.ToDevice(a1_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
|
||||
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any DPP examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
|
||||
message("removing dpp example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
|
||||
@@ -326,7 +326,7 @@ struct Tensor
|
||||
|
||||
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
|
||||
|
||||
void SetZero() { ck::ranges::fill<T>(mData, 0); }
|
||||
void SetZero() { ck::ranges::fill<T>(mData, T{0}); }
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bhalf_t operator()(Is...)
|
||||
ck::half_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::half_t>(value);
|
||||
}
|
||||
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bhalf_t operator()(Is...)
|
||||
ck::f8_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::f8_t>(value);
|
||||
}
|
||||
@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t Dim>
|
||||
/**
|
||||
* @brief Is used to generate sequential values based on the specified dimension.
|
||||
*
|
||||
* @tparam T The type of the tensor values.
|
||||
* @tparam Dim The specific dimension used for generation.
|
||||
*
|
||||
* GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor:
|
||||
*
|
||||
* 0 1 2
|
||||
* 0 1 2
|
||||
* 0 1 2
|
||||
*
|
||||
* Essentially, the values generated are logical coordinates of the generated element that
|
||||
* correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column
|
||||
* indices.
|
||||
*
|
||||
*/
|
||||
template <typename T, ck::index_t Dim>
|
||||
struct GeneratorTensor_Sequential
|
||||
{
|
||||
template <typename... Ts>
|
||||
float operator()(Ts... Xs) const
|
||||
T operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
return dims[Dim];
|
||||
|
||||
float tmp = dims[Dim];
|
||||
return ck::type_convert<T>(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -111,8 +111,7 @@ __global__ void
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
@@ -38,8 +38,7 @@ __global__ void
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
|
||||
@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, fp8_storage_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
@@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
|
||||
#else
|
||||
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0)};
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#endif
|
||||
}
|
||||
@@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0)};
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||
}
|
||||
|
||||
988
include/ck/utility/amd_ck_fp8.hpp
Normal file
988
include/ck/utility/amd_ck_fp8.hpp
Normal file
@@ -0,0 +1,988 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
|
||||
#ifdef CK_USE_FNUZ_FP8
|
||||
#define CK_USE_FNUZ_FP8 1
|
||||
#else
|
||||
#define CK_USE_FNUZ_FP8 0
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_OCP_FP8
|
||||
#define CK_USE_OCP_FP8 1
|
||||
#else
|
||||
#define CK_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
using f8_fnuz_t = _BitInt(8);
|
||||
using bf8_fnuz_t = unsigned _BitInt(8);
|
||||
|
||||
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
|
||||
defined(__gfx1201__)) && \
|
||||
__HIP_DEVICE_COMPILE__
|
||||
#define CK_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_FP8_CVT_FAST_PATH 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 0
|
||||
#endif
|
||||
|
||||
typedef unsigned char fp8_storage_t;
|
||||
|
||||
/**
|
||||
* \brief Describes FP8 interpretation
|
||||
*/
|
||||
enum class ck_fp8_interpretation_t
|
||||
{
|
||||
CK_E4M3_OCP = 0, // OCP E4M3
|
||||
CK_E5M2_OCP = 1, // OCP E5M2
|
||||
CK_E4M3_FNUZ = 2, // FP8
|
||||
CK_E5M2_FNUZ = 3, // BF8
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Describes saturation behavior
|
||||
*/
|
||||
enum class ck_saturation_t
|
||||
{
|
||||
CK_NOSAT = 0, // No saturation - replace with NaN or Inf
|
||||
CK_SATFINITE = 1, // Saturate to finite
|
||||
};
|
||||
|
||||
namespace fp8_impl {
|
||||
|
||||
typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
|
||||
{
|
||||
return static_cast<unsigned char>(a) == 0x80;
|
||||
}
|
||||
__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
|
||||
{
|
||||
return static_cast<unsigned char>(a) == 0x80;
|
||||
}
|
||||
|
||||
__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
|
||||
{
|
||||
return (a & 0x7f) == 0x7f;
|
||||
}
|
||||
__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
|
||||
{
|
||||
return (a & 0x7f) > 0x7c;
|
||||
}
|
||||
|
||||
// The conversion function is from rocblas
|
||||
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
|
||||
// This has been modified to handle double types as well
|
||||
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
|
||||
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
|
||||
{
|
||||
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
|
||||
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
|
||||
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
|
||||
static_assert(is_half || is_float || is_double, "only half, float and double are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
|
||||
if constexpr(is_half)
|
||||
{
|
||||
const unsigned short int ihInf = 0x7C00;
|
||||
const unsigned short int ihNegInf = 0xFC00;
|
||||
const unsigned short int ihNaN = 0x7C01;
|
||||
const unsigned short int ihNeg0 = 0x8000;
|
||||
/* Max number in e5m2 57344*/
|
||||
const unsigned short int ifmax = 0x7B00;
|
||||
const unsigned short int ifmin = 0xFB00;
|
||||
|
||||
fInf = bit_cast<_Float16>(ihInf);
|
||||
fNegInf = bit_cast<_Float16>(ihNegInf);
|
||||
fNaN = bit_cast<_Float16>(ihNaN);
|
||||
fNeg0 = bit_cast<_Float16>(ihNeg0);
|
||||
fmax = bit_cast<_Float16>(ifmax);
|
||||
fmin = bit_cast<_Float16>(ifmin);
|
||||
}
|
||||
else if constexpr(is_float)
|
||||
{
|
||||
const unsigned int ifInf = 0x7F800000;
|
||||
const unsigned int ifNegInf = 0xFF800000;
|
||||
const unsigned int ifNaN = 0x7F800001;
|
||||
const unsigned int ifNeg0 = 0x80000000;
|
||||
/* Max number in e5m2 57344*/
|
||||
const unsigned int ifmax = 0x47600000;
|
||||
const unsigned int ifmin = 0xC7600000;
|
||||
|
||||
fInf = bit_cast<float>(ifInf);
|
||||
fNegInf = bit_cast<float>(ifNegInf);
|
||||
fNaN = bit_cast<float>(ifNaN);
|
||||
fNeg0 = bit_cast<float>(ifNeg0);
|
||||
fmax = bit_cast<float>(ifmax);
|
||||
fmin = bit_cast<float>(ifmin);
|
||||
}
|
||||
else if constexpr(is_double)
|
||||
{
|
||||
const unsigned long long ifInf = 0x7FF0000000000000ull;
|
||||
const unsigned long long ifNegInf = 0xFFF0000000000000ull;
|
||||
const unsigned long long ifNaN = 0x7FF0000000000001ull;
|
||||
const unsigned long long ifNeg0 = 0x8000000000000000ull;
|
||||
/* Max number in e5m2 57344*/
|
||||
const unsigned long long ifmax = 0x40EC000000000000ull;
|
||||
const unsigned long long ifmin = 0xC0EC000000000000ull;
|
||||
|
||||
fInf = bit_cast<double>(ifInf);
|
||||
fNegInf = bit_cast<double>(ifNegInf);
|
||||
fNaN = bit_cast<double>(ifNaN);
|
||||
fNeg0 = bit_cast<double>(ifNeg0);
|
||||
fmax = bit_cast<double>(ifmax);
|
||||
fmin = bit_cast<double>(ifmin);
|
||||
}
|
||||
|
||||
if(x == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned long long sign = x >> 7;
|
||||
unsigned long long mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
if(x == 0x80)
|
||||
{
|
||||
return fNaN;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(x == 0x80)
|
||||
{
|
||||
return fNeg0;
|
||||
}
|
||||
if constexpr(we == 4)
|
||||
{ // e4m3
|
||||
if((x & 0x7F) == 0x7F)
|
||||
{
|
||||
return fNaN;
|
||||
}
|
||||
}
|
||||
else if((x & 0x7C) == 0x7C)
|
||||
{ // e5m2
|
||||
if((x & 0x3) == 0)
|
||||
{
|
||||
if constexpr(clip)
|
||||
{
|
||||
return sign ? fmin : fmax;
|
||||
}
|
||||
return sign ? fNegInf : fInf;
|
||||
}
|
||||
return fNaN;
|
||||
}
|
||||
}
|
||||
|
||||
typename __hip_internal::conditional<
|
||||
sizeof(T) == 2,
|
||||
unsigned short int,
|
||||
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
|
||||
type>::type retval;
|
||||
|
||||
if constexpr(we == 5 && is_half && !is_fnuz)
|
||||
{
|
||||
retval = x << 8;
|
||||
return bit_cast<T>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if(exponent == 0)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + __clz(mantissa) - (32 - wm);
|
||||
#else
|
||||
int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
|
||||
#endif
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1ull << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if(exponent <= 0)
|
||||
{
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if constexpr(sizeof(T) == 2)
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
else if constexpr(sizeof(T) == 4)
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
else
|
||||
retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
|
||||
|
||||
return bit_cast<T>(retval);
|
||||
}
|
||||
|
||||
#if CK_FP8_CVT_FAST_PATH
|
||||
template <ck_fp8_interpretation_t interpret>
|
||||
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
|
||||
{
|
||||
union
|
||||
{
|
||||
unsigned int i32val;
|
||||
unsigned char i8val[4];
|
||||
} val;
|
||||
val.i8val[0] = v;
|
||||
|
||||
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
|
||||
interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
|
||||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
|
||||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
|
||||
"Only FNUZ and OCP interpretations are supported");
|
||||
|
||||
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
|
||||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
|
||||
{
|
||||
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <ck_fp8_interpretation_t interpret>
|
||||
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
|
||||
{
|
||||
const auto i16val = bit_cast<uint16_t>(v);
|
||||
|
||||
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
|
||||
interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
|
||||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
|
||||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
|
||||
"Only FNUZ and OCP interpretations are supported");
|
||||
|
||||
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
|
||||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
|
||||
{
|
||||
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace fp8_impl
|
||||
|
||||
struct f8_ocp_t
|
||||
{
|
||||
using data_type = fp8_storage_t;
|
||||
data_type data;
|
||||
|
||||
static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
|
||||
static constexpr ck_fp8_interpretation_t default_interpret =
|
||||
ck_fp8_interpretation_t::CK_E4M3_OCP;
|
||||
|
||||
static constexpr unsigned int we = 4; // exponent width
|
||||
static constexpr unsigned int wm = 3; // mantissa width
|
||||
|
||||
__host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
|
||||
{
|
||||
return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
__host__ __device__ explicit operator float() const
|
||||
#else
|
||||
__host__ explicit operator float() const
|
||||
#endif
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, wm, we, false>(
|
||||
this->data); // XXX: clip==false must be consistent with operator _Float16
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
__host__ __device__ explicit operator _Float16() const
|
||||
#else
|
||||
__host__ explicit operator _Float16() const
|
||||
#endif
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
|
||||
this->data); // XXX: clip==false must be consistent with operator float
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct bf8_ocp_t
|
||||
{
|
||||
using data_type = fp8_storage_t;
|
||||
data_type data;
|
||||
|
||||
static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
|
||||
static constexpr ck_fp8_interpretation_t default_interpret =
|
||||
ck_fp8_interpretation_t::CK_E5M2_OCP;
|
||||
|
||||
static constexpr unsigned int we = 5; // exponent width
|
||||
static constexpr unsigned int wm = 2; // mantissa width
|
||||
|
||||
__host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
|
||||
{
|
||||
return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
__host__ __device__ explicit operator float() const
|
||||
|
||||
#else
|
||||
__host__ explicit operator float() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, wm, we, false>(
|
||||
this->data); // XXX: clip==false must be consistent with operator _Float16
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
__host__ __device__ explicit operator _Float16() const
|
||||
#else
|
||||
__host__ explicit operator _Float16() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
|
||||
this->data); // XXX: clip==false must be consistent with operator float
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ static inline constexpr bool fp8_is_nan(T);
|
||||
|
||||
template <>
|
||||
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
|
||||
{
|
||||
return fp8_impl::ocp_f8_is_nan(a.data);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
|
||||
{
|
||||
return fp8_impl::ocp_bf8_is_nan(a.data);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
|
||||
{
|
||||
return fp8_impl::fnuz_f8_is_nan(a);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
|
||||
{
|
||||
return fp8_impl::fnuz_bf8_is_nan(a);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> ||
|
||||
std::is_same_v<T, bf8_fnuz_t> || std::is_same_v<T, f8_fnuz_t>,
|
||||
bool> = true>
|
||||
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
|
||||
{
|
||||
return (a.data & 0x7f) == 0x7c;
|
||||
}
|
||||
|
||||
namespace fp8_impl {
|
||||
|
||||
// Assertions to check for supported conversion types
|
||||
#define __assert_ocp_support(interp) \
|
||||
{ \
|
||||
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
|
||||
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
|
||||
{ \
|
||||
__hip_assert(false && "type is unsupported by current target device"); \
|
||||
} \
|
||||
}
|
||||
#define __assert_fnuz_support(interp) \
|
||||
{ \
|
||||
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
|
||||
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
|
||||
{ \
|
||||
__hip_assert(false && "type is unsupported by current target device"); \
|
||||
} \
|
||||
}
|
||||
|
||||
__host__ __device__ static inline void
|
||||
__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
|
||||
#if CK_USE_OCP_FP8
|
||||
__assert_ocp_support(interp);
|
||||
#endif
|
||||
#if CK_USE_FNUZ_FP8
|
||||
__assert_fnuz_support(interp);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CK_FP8_CVT_FAST_PATH
|
||||
// The conversion function is from rocblas
|
||||
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
|
||||
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
|
||||
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
|
||||
{
|
||||
fp8_storage_t i8data;
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
unsigned int i32val;
|
||||
unsigned char i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
unsigned int ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if constexpr(saturate)
|
||||
{
|
||||
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
|
||||
{
|
||||
if((val.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
}
|
||||
else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
|
||||
{ // OCP type
|
||||
if((val.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((val.i32val & 0x7F800000) != 0x7F800000)
|
||||
{ /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
|
||||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
|
||||
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
|
||||
: __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0]; // little endian
|
||||
}
|
||||
else
|
||||
{ // RNE CVT
|
||||
ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
|
||||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
|
||||
? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
|
||||
: __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
|
||||
val.fval,
|
||||
ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
}
|
||||
return i8data;
|
||||
}
|
||||
#endif // CK_FP8_CVT_FAST_PATH
|
||||
|
||||
// The conversion function is from rocblas
|
||||
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
|
||||
// This has been modified to add double types conversion as well
|
||||
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
|
||||
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
|
||||
{
|
||||
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
|
||||
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
|
||||
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
|
||||
static_assert(is_half || is_float || is_double,
|
||||
"Only half, float and double can be cast to f8");
|
||||
|
||||
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
|
||||
|
||||
using T_bitwise = typename __hip_internal::conditional<
|
||||
sizeof(T) == 2,
|
||||
unsigned short int,
|
||||
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
|
||||
type>::type;
|
||||
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
|
||||
|
||||
unsigned long long x{x_bitwise};
|
||||
|
||||
unsigned long long head, mantissa;
|
||||
int exponent, bias;
|
||||
unsigned int sign;
|
||||
unsigned long long fInf, mask;
|
||||
|
||||
if constexpr(sizeof(T) == 8)
|
||||
{
|
||||
head = x & 0xFFF0000000000000ull;
|
||||
mantissa = x & 0xFFFFFFFFFFFFFull;
|
||||
exponent = (head >> 52) & 0x7FF;
|
||||
sign = head >> 63;
|
||||
bias = 1023;
|
||||
fInf = 0x7FF0000000000000ull;
|
||||
mask = 0x7FFFFFFFFFFFFFFFull;
|
||||
}
|
||||
else if constexpr(sizeof(T) == 4)
|
||||
{
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
fInf = 0x7F800000;
|
||||
mask = 0x7FFFFFFF;
|
||||
}
|
||||
else
|
||||
{
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
fInf = 0x7C00;
|
||||
mask = 0x7FFF;
|
||||
}
|
||||
unsigned int signed_inf = 0;
|
||||
unsigned int nan = 0;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
|
||||
nan = 0x80;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(we == 4)
|
||||
{ // e4m3
|
||||
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
|
||||
}
|
||||
else
|
||||
{ // e5m2
|
||||
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
|
||||
}
|
||||
nan = (sign << 7) + 0x7f;
|
||||
}
|
||||
// Max values
|
||||
unsigned long long ifmax = 0;
|
||||
if constexpr(sizeof(T) == 8)
|
||||
{
|
||||
if constexpr(we == 5)
|
||||
{ // 57344
|
||||
ifmax = 0x40EC000000000000ull;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
{ // 240
|
||||
ifmax = 0x406E000000000000ull;
|
||||
}
|
||||
else
|
||||
{ // 448
|
||||
ifmax = 0x407C000000000000ull;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(sizeof(T) == 4)
|
||||
{
|
||||
if constexpr(we == 5)
|
||||
{
|
||||
ifmax = 0x47600000;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
ifmax = 0x43700000;
|
||||
}
|
||||
else
|
||||
{
|
||||
ifmax = 0x43E00000;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(we == 5)
|
||||
{
|
||||
ifmax = 0x7B00;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
ifmax = 0x5B80;
|
||||
}
|
||||
else
|
||||
{
|
||||
ifmax = 0x5F00;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Deal with inf and NaNs
|
||||
if((x & fInf) == fInf)
|
||||
{
|
||||
if constexpr(is_fnuz)
|
||||
return signed_inf;
|
||||
|
||||
return mantissa != 0 ? nan : signed_inf;
|
||||
}
|
||||
|
||||
if((x & mask) > ifmax)
|
||||
{
|
||||
return signed_inf;
|
||||
}
|
||||
|
||||
if(x == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||
// need to check whether there is carry and adjust exponent and mantissa again
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if(exponent == 0)
|
||||
{ // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = f8_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
}
|
||||
else
|
||||
{ // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= f8_denormal_act_exponent)
|
||||
{
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
|
||||
actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
}
|
||||
else
|
||||
{ // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
|
||||
// for this case, act_exponent could be larger. Just
|
||||
// that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
(1ull << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part and
|
||||
make something not midpoint look like midpoint. For example, the fp16 number
|
||||
0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
|
||||
by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1ull << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent =
|
||||
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
|
||||
bool odd =
|
||||
mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if(f8_exponent == 0)
|
||||
{
|
||||
if((1ull << mfmt) & mantissa)
|
||||
{
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1ull << (mfmt + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - 1;
|
||||
if(f8_exponent > max_exp)
|
||||
{
|
||||
if constexpr(clip)
|
||||
{
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
}
|
||||
else
|
||||
{
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if(f8_exponent == 0 && mantissa == 0)
|
||||
return is_fnuz ? 0 : (sign << 7);
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert float to @p fp8_storage_t
|
||||
*
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam sat saturation of fp8
|
||||
* \param f float number
|
||||
* \return fp8_storage_t
|
||||
*/
|
||||
template <ck_fp8_interpretation_t interp,
|
||||
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
|
||||
bool stochastic_rounding = false>
|
||||
#if CK_FP8_CVT_FAST_PATH
|
||||
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
{
|
||||
__is_interpret_supported(interp);
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
}
|
||||
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
f, rng);
|
||||
#else
|
||||
#if CK_USE_OCP_FP8
|
||||
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
{
|
||||
#else
|
||||
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
{
|
||||
#endif
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
}
|
||||
|
||||
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
|
||||
{
|
||||
return cast_to_f8<float,
|
||||
3,
|
||||
4,
|
||||
true,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(f, rng);
|
||||
}
|
||||
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
|
||||
{
|
||||
return cast_to_f8<float,
|
||||
2,
|
||||
5,
|
||||
true,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(f, rng);
|
||||
}
|
||||
else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
|
||||
{
|
||||
return cast_to_f8<float,
|
||||
3,
|
||||
4,
|
||||
false,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(f, rng);
|
||||
}
|
||||
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
|
||||
{
|
||||
return cast_to_f8<float,
|
||||
2,
|
||||
5,
|
||||
false,
|
||||
sat == ck_saturation_t::CK_SATFINITE,
|
||||
stochastic_rounding>(f, rng);
|
||||
}
|
||||
else
|
||||
{
|
||||
__hip_assert(false && "FP8 type is not supported by current target device");
|
||||
return 0;
|
||||
}
|
||||
#endif // CK_FP8_CVT_FAST_PATH
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief convert _Float16 to @p fp8_storage_t
|
||||
*
|
||||
* \tparam sat saturation of fp8
|
||||
* \tparam interp interpretation of fp8
|
||||
* \tparam stochastic_rounding switch between RNE and SR
|
||||
* \param x _Float16 value
|
||||
* \return fp8_storage_t
|
||||
*/
|
||||
template <ck_fp8_interpretation_t interp,
|
||||
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
|
||||
bool stochastic_rounding = false>
|
||||
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
|
||||
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
|
||||
#else
|
||||
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
|
||||
#endif
|
||||
{
|
||||
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
|
||||
}
|
||||
|
||||
} // namespace fp8_impl
|
||||
|
||||
// Declare a template function for fp8 conversion using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
|
||||
// convert fp32 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
// convert _Float16 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return bf8_ocp_t{
|
||||
fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
|
||||
x)};
|
||||
}
|
||||
|
||||
// Declare a template function for fp8 conversion using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
|
||||
// convert fp32 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
|
||||
{
|
||||
return f8_ocp_t{
|
||||
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
|
||||
x)};
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
// convert _Float16 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
|
||||
f8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
// convert _Float16 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
|
||||
{
|
||||
return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
|
||||
bf8_ocp_t::default_saturation,
|
||||
true>(x)};
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
using f8_t = f8_ocp_t;
|
||||
using bf8_t = bf8_ocp_t;
|
||||
#define CK_FP8_TYPE_FNUZ 0
|
||||
#define CK_FP8_TYPE_OCP 1
|
||||
#else
|
||||
using f8_t = f8_fnuz_t;
|
||||
using bf8_t = bf8_fnuz_t;
|
||||
#define CK_FP8_TYPE_FNUZ 1
|
||||
#define CK_FP8_TYPE_OCP 0
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for gfx94x models
|
||||
// Define the common macro for MI300 models
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/amd_ck_fp8.hpp"
|
||||
#include "ck/utility/statically_indexed_array.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -10,8 +11,6 @@ namespace ck {
|
||||
using bhalf_t = ushort;
|
||||
using half_t = _Float16;
|
||||
using int4_t = _BitInt(4);
|
||||
using f8_t = _BitInt(8);
|
||||
using bf8_t = unsigned _BitInt(8);
|
||||
|
||||
inline constexpr auto next_pow2(uint32_t x)
|
||||
{
|
||||
@@ -19,14 +18,15 @@ inline constexpr auto next_pow2(uint32_t x)
|
||||
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
|
||||
}
|
||||
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
|
||||
// native types: bool
|
||||
template <typename T>
|
||||
inline constexpr bool is_native_type()
|
||||
{
|
||||
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
|
||||
is_same<T, bool>::value;
|
||||
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
|
||||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
|
||||
}
|
||||
|
||||
// vector_type
|
||||
@@ -166,16 +166,30 @@ struct scalar_type<int4_t>
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct scalar_type<f8_t>
|
||||
struct scalar_type<f8_fnuz_t>
|
||||
{
|
||||
using type = f8_t;
|
||||
using type = f8_fnuz_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bf8_t>
|
||||
struct scalar_type<bf8_fnuz_t>
|
||||
{
|
||||
using type = bf8_t;
|
||||
using type = bf8_fnuz_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<f8_ocp_t>
|
||||
{
|
||||
using type = f8_ocp_t::data_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bf8_ocp_t>
|
||||
{
|
||||
using type = bf8_ocp_t::data_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
@@ -1010,60 +1024,203 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct non_native_vector_base
|
||||
template <typename T, index_t N, typename Enable = void>
|
||||
struct non_native_vector_base;
|
||||
|
||||
template <typename T>
|
||||
struct nnvb_data_t_selector
|
||||
{
|
||||
using type = non_native_vector_base<T, N>;
|
||||
using type = unsigned _BitInt(8 * sizeof(T));
|
||||
};
|
||||
|
||||
__host__ __device__ non_native_vector_base() = default;
|
||||
__host__ __device__ non_native_vector_base(const type&) = default;
|
||||
__host__ __device__ non_native_vector_base(type&&) = default;
|
||||
__host__ __device__ ~non_native_vector_base() = default;
|
||||
template <>
|
||||
struct nnvb_data_t_selector<f8_ocp_t>
|
||||
{
|
||||
using type = f8_ocp_t::data_type;
|
||||
};
|
||||
template <>
|
||||
struct nnvb_data_t_selector<bf8_ocp_t>
|
||||
{
|
||||
using type = bf8_ocp_t::data_type;
|
||||
};
|
||||
|
||||
T d[N];
|
||||
template <typename T, index_t N>
|
||||
struct non_native_vector_base<
|
||||
T,
|
||||
N,
|
||||
std::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
|
||||
{
|
||||
using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T
|
||||
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
|
||||
using data_v = data_t __attribute__((ext_vector_type(N)));
|
||||
using type = non_native_vector_base<T, N>;
|
||||
|
||||
union alignas(next_pow2(N * sizeof(T)))
|
||||
{
|
||||
data_v dN; // storage vector;
|
||||
StaticallyIndexedArray<data_t, N> dxN;
|
||||
StaticallyIndexedArray<T, N> dTxN;
|
||||
StaticallyIndexedArray<data_v, 1> dNx1;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {}
|
||||
__host__ __device__ constexpr non_native_vector_base(T f)
|
||||
: non_native_vector_base(bit_cast<data_t>(f))
|
||||
{
|
||||
}
|
||||
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
|
||||
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
|
||||
|
||||
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
|
||||
__host__ __device__ constexpr operator data_t() const
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return data_.dxN[Number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
return data_.dxN; // XXX this should cause an error
|
||||
}
|
||||
}
|
||||
__host__ __device__ constexpr operator T() const
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return data_.dTxN[Number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
return data_.dTxN; // XXX this should cause an error
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same_v<X, data_t>)
|
||||
{
|
||||
return data_.dxN;
|
||||
}
|
||||
else if constexpr(is_same_v<X, T>)
|
||||
{
|
||||
return data_.dTxN;
|
||||
}
|
||||
else if constexpr(is_same_v<X, data_v>)
|
||||
{
|
||||
return data_.dNx1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same_v<X, data_t>)
|
||||
{
|
||||
return data_.dxN;
|
||||
}
|
||||
else if constexpr(is_same_v<X, T>)
|
||||
{
|
||||
return data_.dTxN;
|
||||
}
|
||||
else if constexpr(is_same_v<X, data_v>)
|
||||
{
|
||||
return data_.dNx1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<non_native_vector_base<T, N>>;
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
|
||||
{
|
||||
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
// non-native vector_type implementation
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
using d1_t = T;
|
||||
using d1_nnv_t = non_native_vector_base<T, 1>;
|
||||
using type = d1_nnv_t;
|
||||
|
||||
union alignas(next_pow2(1 * sizeof(T)))
|
||||
{
|
||||
d1_t d1_;
|
||||
StaticallyIndexedArray<d1_t, 1> d1x1_;
|
||||
d1_nnv_t d1_nnv_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
__host__ __device__ constexpr vector_type() : data_{d1_t{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
return data_.d1x1_;
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
return data_.d1x1_;
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d1_t = T;
|
||||
using d1_nnv_t = non_native_vector_base<T, 1>;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
|
||||
using type = d2_t;
|
||||
|
||||
@@ -1081,10 +1238,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x2_;
|
||||
}
|
||||
@@ -1101,10 +1259,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x2_;
|
||||
}
|
||||
@@ -1122,9 +1281,10 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d1_t = T;
|
||||
using d1_nnv_t = non_native_vector_base<T, 1>;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
|
||||
using type = d4_t;
|
||||
|
||||
@@ -1143,10 +1303,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x4_;
|
||||
}
|
||||
@@ -1167,10 +1328,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x4_;
|
||||
}
|
||||
@@ -1192,10 +1354,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
using d1_t = T;
|
||||
using d1_nnv_t = non_native_vector_base<T, 1>;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
|
||||
using type = d8_t;
|
||||
|
||||
@@ -1215,11 +1378,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x8_;
|
||||
}
|
||||
@@ -1244,11 +1408,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x8_;
|
||||
}
|
||||
@@ -1274,11 +1439,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
using d16_t = non_native_vector_base<T, 16>;
|
||||
using d1_t = T;
|
||||
using d1_nnv_t = non_native_vector_base<T, 1>;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
using d16_t = non_native_vector_base<T, 16>;
|
||||
|
||||
using type = d16_t;
|
||||
|
||||
@@ -1299,12 +1465,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x16_;
|
||||
}
|
||||
@@ -1333,12 +1499,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
|
||||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
|
||||
{
|
||||
return data_.d1x16_;
|
||||
}
|
||||
@@ -1632,20 +1798,70 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
|
||||
// f8
|
||||
using f8x2_t = typename vector_type<f8_t, 2>::type;
|
||||
using f8x4_t = typename vector_type<f8_t, 4>::type;
|
||||
using f8x8_t = typename vector_type<f8_t, 8>::type;
|
||||
using f8x16_t = typename vector_type<f8_t, 16>::type;
|
||||
using f8x32_t = typename vector_type<f8_t, 32>::type;
|
||||
using f8x64_t = typename vector_type<f8_t, 64>::type;
|
||||
using f8x2_fnuz_t = typename vector_type<f8_fnuz_t, 2>::type;
|
||||
using f8x4_fnuz_t = typename vector_type<f8_fnuz_t, 4>::type;
|
||||
using f8x8_fnuz_t = typename vector_type<f8_fnuz_t, 8>::type;
|
||||
using f8x16_fnuz_t = typename vector_type<f8_fnuz_t, 16>::type;
|
||||
using f8x32_fnuz_t = typename vector_type<f8_fnuz_t, 32>::type;
|
||||
using f8x64_fnuz_t = typename vector_type<f8_fnuz_t, 64>::type;
|
||||
|
||||
// bf8
|
||||
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
|
||||
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
|
||||
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
|
||||
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
|
||||
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
|
||||
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
|
||||
using bf8x2_fnuz_t = typename vector_type<bf8_fnuz_t, 2>::type;
|
||||
using bf8x4_fnuz_t = typename vector_type<bf8_fnuz_t, 4>::type;
|
||||
using bf8x8_fnuz_t = typename vector_type<bf8_fnuz_t, 8>::type;
|
||||
using bf8x16_fnuz_t = typename vector_type<bf8_fnuz_t, 16>::type;
|
||||
using bf8x32_fnuz_t = typename vector_type<bf8_fnuz_t, 32>::type;
|
||||
using bf8x64_fnuz_t = typename vector_type<bf8_fnuz_t, 64>::type;
|
||||
|
||||
// f8
|
||||
using f8x2_ocp_t = typename vector_type<f8_ocp_t, 2>::type;
|
||||
using f8x4_ocp_t = typename vector_type<f8_ocp_t, 4>::type;
|
||||
using f8x8_ocp_t = typename vector_type<f8_ocp_t, 8>::type;
|
||||
using f8x16_ocp_t = typename vector_type<f8_ocp_t, 16>::type;
|
||||
using f8x32_ocp_t = typename vector_type<f8_ocp_t, 32>::type;
|
||||
using f8x64_ocp_t = typename vector_type<f8_ocp_t, 64>::type;
|
||||
|
||||
// bf8
|
||||
using bf8x2_ocp_t = typename vector_type<bf8_ocp_t, 2>::type;
|
||||
using bf8x4_ocp_t = typename vector_type<bf8_ocp_t, 4>::type;
|
||||
using bf8x8_ocp_t = typename vector_type<bf8_ocp_t, 8>::type;
|
||||
using bf8x16_ocp_t = typename vector_type<bf8_ocp_t, 16>::type;
|
||||
using bf8x32_ocp_t = typename vector_type<bf8_ocp_t, 32>::type;
|
||||
using bf8x64_ocp_t = typename vector_type<bf8_ocp_t, 64>::type;
|
||||
|
||||
#if CK_FP8_TYPE_OCP
|
||||
// f8
|
||||
using f8x2_t = f8x2_ocp_t;
|
||||
using f8x4_t = f8x4_ocp_t;
|
||||
using f8x8_t = f8x8_ocp_t;
|
||||
using f8x16_t = f8x16_ocp_t;
|
||||
using f8x32_t = f8x32_ocp_t;
|
||||
using f8x64_t = f8x64_ocp_t;
|
||||
|
||||
// bf8
|
||||
using bf8x2_t = bf8x2_ocp_t;
|
||||
using bf8x4_t = bf8x4_ocp_t;
|
||||
using bf8x8_t = bf8x8_ocp_t;
|
||||
using bf8x16_t = bf8x16_ocp_t;
|
||||
using bf8x32_t = bf8x32_ocp_t;
|
||||
using bf8x64_t = bf8x64_ocp_t;
|
||||
#elif CK_FP8_TYPE_FNUZ
|
||||
// f8
|
||||
using f8x2_t = f8x2_fnuz_t;
|
||||
using f8x4_t = f8x4_fnuz_t;
|
||||
using f8x8_t = f8x8_fnuz_t;
|
||||
using f8x16_t = f8x16_fnuz_t;
|
||||
using f8x32_t = f8x32_fnuz_t;
|
||||
using f8x64_t = f8x64_fnuz_t;
|
||||
|
||||
// bf8
|
||||
using bf8x2_t = bf8x2_fnuz_t;
|
||||
using bf8x4_t = bf8x4_fnuz_t;
|
||||
using bf8x8_t = bf8x8_fnuz_t;
|
||||
using bf8x16_t = bf8x16_fnuz_t;
|
||||
using bf8x32_t = bf8x32_fnuz_t;
|
||||
using bf8x64_t = bf8x64_fnuz_t;
|
||||
#endif
|
||||
|
||||
// u8
|
||||
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
|
||||
@@ -1702,7 +1918,7 @@ struct NumericLimits<int4_t>
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_t>
|
||||
struct NumericLimits<f8_fnuz_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 8
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
@@ -1715,17 +1931,17 @@ struct NumericLimits<f8_t>
|
||||
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
|
||||
|
||||
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
|
||||
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
|
||||
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
|
||||
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
|
||||
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf8_t>
|
||||
struct NumericLimits<bf8_fnuz_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 16
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
@@ -1738,13 +1954,59 @@ struct NumericLimits<bf8_t>
|
||||
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
|
||||
|
||||
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
|
||||
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_ocp_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
|
||||
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
|
||||
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
|
||||
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Lowest()
|
||||
{
|
||||
return bit_cast<f8_ocp_t>(binary_lowest);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
|
||||
{
|
||||
return bit_cast<f8_ocp_t>(binary_qnan);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf8_ocp_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
|
||||
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
|
||||
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
|
||||
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Lowest()
|
||||
{
|
||||
return bit_cast<bf8_ocp_t>(binary_lowest);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
|
||||
{
|
||||
return bit_cast<bf8_ocp_t>(binary_qnan);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -1787,7 +2049,7 @@ struct NumericUtils<half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<f8_t>
|
||||
struct NumericUtils<f8_fnuz_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
@@ -1796,13 +2058,28 @@ struct NumericUtils<f8_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bf8_t>
|
||||
struct NumericUtils<bf8_fnuz_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
static constexpr int bias = 16; // negative zero nan mode
|
||||
// static constexpr int bias = 15; // ieee mode
|
||||
};
|
||||
template <>
|
||||
struct NumericUtils<f8_ocp_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
static constexpr int bias = 7;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bf8_ocp_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
static constexpr int bias = 15;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bhalf_t>
|
||||
|
||||
@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
|
||||
static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
static inline __host__ bool isnan(int4_t x)
|
||||
@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
|
||||
static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
|
||||
|
||||
static inline __device__ half_t sqrt(half_t x)
|
||||
{
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Pseudo random number generator
|
||||
@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
|
||||
}
|
||||
|
||||
// version for fp16
|
||||
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
|
||||
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
|
||||
}
|
||||
|
||||
// return 0 if data is not fp16 or fp32
|
||||
template <typename T,
|
||||
uint32_t seed_t,
|
||||
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
|
||||
template <
|
||||
typename T,
|
||||
uint32_t seed_t,
|
||||
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
std::ignore = id;
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck/utility/array.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for gfx94x models
|
||||
// Define the common macro for MI300 models
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
|
||||
return type_convert<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
|
||||
{
|
||||
return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
|
||||
{
|
||||
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
|
||||
}
|
||||
|
||||
// Convert X to Y
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y type_convert_sp(X x)
|
||||
@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
|
||||
// convert fp32 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::cast_to_f8<half_t,
|
||||
f8_fnuz_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == f8_rounding_mode::stochastic)>(x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::cast_to_f8<float,
|
||||
bf8_fnuz_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == f8_rounding_mode::stochastic)>(x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8 with stochastic rounding
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_sr<bf8_t>(type_convert<float>(x));
|
||||
return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::cast_to_f8<half_t,
|
||||
bf8_fnuz_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == f8_rounding_mode::stochastic)>(x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
|
||||
|
||||
// convert fp32 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
|
||||
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
|
||||
rng);
|
||||
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
|
||||
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_rne<f8_t>(type_convert<float>(x));
|
||||
return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::cast_to_f8<half_t,
|
||||
f8_fnuz_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == f8_rounding_mode::stochastic)>(x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
|
||||
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::cast_to_f8<float,
|
||||
bf8_fnuz_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == f8_rounding_mode::stochastic)>(x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8 with rounding to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
|
||||
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// convert to float and use native converion
|
||||
return f8_convert_rne<bf8_t>(type_convert<float>(x));
|
||||
return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return utils::cast_to_f8<half_t,
|
||||
bf8_fnuz_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == f8_rounding_mode::stochastic)>(x, rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
|
||||
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_t>(x);
|
||||
return f8_convert_sr<f8_fnuz_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_t>(x);
|
||||
return f8_convert_rne<f8_fnuz_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
@@ -392,30 +427,44 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
|
||||
return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
const auto i16val = bit_cast<uint16_t>(x);
|
||||
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
const auto f8x2_v = vector_type<f8_t, 2>(x);
|
||||
const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
|
||||
vector_type<float, 2> f32x2_v;
|
||||
f32x2_v.template AsType<float>()(Number<0>{}) =
|
||||
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
|
||||
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
|
||||
f32x2_v.template AsType<float>()(Number<1>{}) =
|
||||
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
|
||||
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
|
||||
f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
|
||||
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
|
||||
{
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(
|
||||
x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
|
||||
#else
|
||||
return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
|
||||
x.AsType<fp8_storage_t>()[Number<0>{}]),
|
||||
fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
|
||||
x.AsType<fp8_storage_t>()[Number<1>{}])};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
|
||||
{
|
||||
@@ -428,42 +477,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
|
||||
|
||||
// convert fp16 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_t>(x);
|
||||
return f8_convert_sr<f8_fnuz_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_t>(x);
|
||||
return f8_convert_rne<f8_fnuz_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to fp8
|
||||
template <>
|
||||
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<f8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<f8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
|
||||
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_t>(x);
|
||||
return f8_convert_sr<bf8_fnuz_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_t>(x);
|
||||
return f8_convert_rne<bf8_fnuz_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp32 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
|
||||
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
@@ -473,31 +544,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_t>(x);
|
||||
return f8_convert_sr<bf8_fnuz_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_t>(x);
|
||||
return f8_convert_rne<bf8_fnuz_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert fp16 to bf8
|
||||
template <>
|
||||
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
|
||||
{
|
||||
#if CK_USE_SR_F8_CONVERSION
|
||||
return f8_convert_sr<bf8_ocp_t>(x);
|
||||
#else
|
||||
return f8_convert_rne<bf8_ocp_t>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bf8 to fp16
|
||||
template <>
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
ComputeTypeA v_a = 0;
|
||||
ComputeTypeB v_b = 0;
|
||||
AccDataType v_acc{0};
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
CDataType v_c{0};
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endforeach()
|
||||
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
|
||||
message("removing mha instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -346,7 +346,7 @@ if(CK_DEVICE_CONV_INSTANCES)
|
||||
endif()
|
||||
if(CK_DEVICE_MHA_INSTANCES)
|
||||
set(gpu_list ${INST_TARGETS})
|
||||
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
|
||||
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
|
||||
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
|
||||
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
|
||||
target_compile_features(device_mha_operations PUBLIC)
|
||||
|
||||
@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, false>{});
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_f8_instances(
|
||||
@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, true>{});
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
|
||||
break;
|
||||
default:
|
||||
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
|
||||
ck::utils::FillConstant<ADataType>{type_convert<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{type_convert<BDataType>(1.f)}(b_k_n);
|
||||
break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
|
||||
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_fp8 test_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_fp8 PRIVATE utility)
|
||||
|
||||
|
||||
add_custom_target(test_fp8)
|
||||
|
||||
if (CK_USE_OCP_FP8)
|
||||
add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_fp8_ocp PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_bf8_ocp test_bf8_ocp.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_bf8_ocp PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_dependencies(test_fp8 test_fp8_ocp)
|
||||
add_dependencies(test_fp8 test_bf8_ocp)
|
||||
endif()
|
||||
add_gtest_executable(test_bf8 test_bf8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_bf8 PRIVATE utility)
|
||||
|
||||
if (CK_USE_FNUZ_FP8)
|
||||
add_gtest_executable(test_fp8_fnuz test_fp8_fnuz.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_fp8_fnuz PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_bf8_fnuz test_bf8_fnuz.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_bf8_fnuz PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_dependencies(test_fp8 test_fp8_fnuz)
|
||||
add_dependencies(test_fp8 test_bf8_fnuz)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_custom_type test_custom_type.cpp)
|
||||
|
||||
@@ -5,158 +5,169 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bf8_t;
|
||||
using ck::bf8_fnuz_t;
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(BF8, NumericLimits)
|
||||
TEST(BF8FNUZ, NumericLimits)
|
||||
{
|
||||
// constants given for negative zero nan mode
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Min(), type_convert<bf8_fnuz_t>(0x04));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Max(), type_convert<bf8_fnuz_t>(0x7F));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Lowest(), type_convert<bf8_fnuz_t>(0xFF));
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(), type_convert<bf8_fnuz_t>(0x80));
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP32Nearest)
|
||||
TEST(BF8FNUZ, ConvertFP32Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_t>(0.0f)), abs_tol);
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(0.0f)), abs_tol);
|
||||
// don't run the next test on gfx11 devices
|
||||
#ifndef CK_SKIP_FLAKY_F8_TEST
|
||||
// convert minimal float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::min())),
|
||||
type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
#endif
|
||||
// convert maximal bf8_t to float and check if equal to 57344.0
|
||||
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_rne<bf8_t>(57344.0f)), abs_tol);
|
||||
|
||||
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
|
||||
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
|
||||
ASSERT_NEAR(
|
||||
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
|
||||
// convert maximal float to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(57344.0f,
|
||||
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::max())),
|
||||
ASSERT_NEAR(max_bf8_t_float,
|
||||
type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to bf8_t and check if it is qNan
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
f8_convert_rne<bf8_t>(std::numeric_limits<float>::infinity()),
|
||||
// convert inf float to bf8_fnuz_t and check if it is qNan
|
||||
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to bf8 and back, check if holds
|
||||
float pos_float = 0.0000762939f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative norm float value to bf8 and back, check if holds
|
||||
float neg_float = -0.0000610351f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to bf8 and back, check if holds
|
||||
pos_float = 0.0000305175f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to bf8 and back, check if holds
|
||||
neg_float = -0.0000152587f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP32Stochastic)
|
||||
TEST(BF8FNUZ, ConvertFP32Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol);
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(0.0f)), abs_tol);
|
||||
// convert minimal float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())),
|
||||
type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to float and check if equal to 57344.0
|
||||
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol);
|
||||
|
||||
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
|
||||
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
|
||||
ASSERT_NEAR(
|
||||
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
|
||||
// convert maximal float to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(57344.0f,
|
||||
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())),
|
||||
ASSERT_NEAR(max_bf8_t_float,
|
||||
type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to bf8_t and check if it is qNan
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()),
|
||||
// convert inf float to bf8_fnuz_t and check if it is qNan
|
||||
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to bf8 and back, check if holds
|
||||
float pos_float = 0.0000762939f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative norm float value to bf8 and back, check if holds
|
||||
float neg_float = -0.0000610351f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to bf8 and back, check if holds
|
||||
pos_float = 0.0000305175f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to bf8 and back, check if holds
|
||||
neg_float = -0.0000152587f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP16Nearest)
|
||||
TEST(BF8FNUZ, ConvertFP16Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-3;
|
||||
// convert 0 fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{0.0})), abs_tol);
|
||||
ASSERT_NEAR(
|
||||
half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(half_t{0.0})), abs_tol);
|
||||
// convert minimal fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Min())),
|
||||
type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to fp16 and check if equal to 57344.0
|
||||
|
||||
const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
|
||||
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
|
||||
ASSERT_NEAR(
|
||||
half_t{57344.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{57344.0})), abs_tol);
|
||||
max_bf8_t_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
|
||||
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(half_t{57344.0},
|
||||
type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
ASSERT_NEAR(max_bf8_t_half,
|
||||
type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
// convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to bf8 and back, check if holds
|
||||
half_t pos_half = half_t{0.0000762939};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative norm fp16 value to bf8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.0000610351};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to bf8 and back, check if holds
|
||||
pos_half = half_t{0.0000305175};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to bf8 and back, check if holds
|
||||
neg_half = half_t{-0.0000152587};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BF8, ConvertFP16Stochastic)
|
||||
TEST(BF8FNUZ, ConvertFP16Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-3;
|
||||
// convert 0 fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol);
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(half_t{0.0})), abs_tol);
|
||||
// convert minimal fp16 to bf8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())),
|
||||
type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
|
||||
abs_tol);
|
||||
// convert maximal bf8_t to fp16 and check if equal to 57344.0
|
||||
|
||||
const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
|
||||
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
|
||||
ASSERT_NEAR(
|
||||
half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol);
|
||||
max_bf8_t_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
|
||||
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
|
||||
ASSERT_NEAR(half_t{57344.0},
|
||||
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
ASSERT_NEAR(max_bf8_t_half,
|
||||
type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(type_convert<bf8_t>(0x80),
|
||||
f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
// convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to bf8 and back, check if holds
|
||||
half_t pos_half = half_t{0.0000762939};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative norm fp16 value to bf8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.0000610351};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to bf8 and back, check if holds
|
||||
pos_half = half_t{0.0000305175};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to bf8 and back, check if holds
|
||||
neg_half = half_t{-0.0000152587};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
|
||||
}
|
||||
268
test/data_type/test_bf8_ocp.cpp
Normal file
268
test/data_type/test_bf8_ocp.cpp
Normal file
@@ -0,0 +1,268 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bf8_ocp_t;
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(BF8OCP, NumericLimits)
|
||||
{ // constants given for OCP FP8
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Min(),
|
||||
type_convert<bf8_ocp_t>(0x04)); // 0b00000100 = 2^-14
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
|
||||
type_convert<bf8_ocp_t>(0x7B)); // 0b01111011 = 57344
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Lowest(),
|
||||
type_convert<bf8_ocp_t>(0xFB)); // 0b11111011 = -57344
|
||||
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::QuietNaN().data,
|
||||
type_convert<bf8_ocp_t>(0x7D).data); // 0b01111101
|
||||
EXPECT_FALSE(ck::NumericLimits<bf8_ocp_t>::QuietNaN() ==
|
||||
ck::NumericLimits<bf8_ocp_t>::QuietNaN());
|
||||
EXPECT_TRUE(ck::fp8_is_inf(type_convert<bf8_ocp_t>(0xFC)) &&
|
||||
ck::fp8_is_inf(type_convert<bf8_ocp_t>(0x7C)));
|
||||
}
|
||||
|
||||
TEST(BF8OCP, ConvertFP32Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
|
||||
// convert 0 float to bfp8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(0.0f)), 0.0f);
|
||||
|
||||
// convert minimal float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
|
||||
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
|
||||
|
||||
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
|
||||
ASSERT_NEAR(
|
||||
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
|
||||
|
||||
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_bf8_t_float,
|
||||
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::max())),
|
||||
0.0f);
|
||||
|
||||
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
|
||||
f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
|
||||
|
||||
// positive normal float value to bf8 and back, check if holds
|
||||
float pos_float = 0.0000762939f; // 10*2^-17
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_float)), abs_tol);
|
||||
|
||||
// negative smallest normal bf8 value to bf8 and back, check if holds
|
||||
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
|
||||
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), 0.0f);
|
||||
|
||||
// positive subnorm float value to bf8 and back, check if holds
|
||||
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
|
||||
ASSERT_NEAR(
|
||||
pos_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
|
||||
|
||||
// min subnorm bf8 value to bf8 and back, check if holds
|
||||
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
|
||||
ASSERT_NEAR(
|
||||
min_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
|
||||
|
||||
// smaller than min subnorm bf8 value to bf8 must be zero
|
||||
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
|
||||
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
|
||||
|
||||
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
|
||||
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
|
||||
}
|
||||
|
||||
TEST(BF8OCP, ConvertFP32Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
|
||||
// convert 0 float to bfp8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_ocp_t>(0.0f)), 0.0f);
|
||||
|
||||
// convert minimal float to bf8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
|
||||
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
|
||||
|
||||
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
|
||||
ASSERT_NEAR(
|
||||
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
|
||||
|
||||
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_bf8_t_float,
|
||||
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::max())),
|
||||
0.0f);
|
||||
|
||||
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
|
||||
f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
|
||||
|
||||
// positive normal float value to bf8 and back, check if holds
|
||||
float pos_float = 0.0000762939f; // 10*2^-17
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_float)), abs_tol);
|
||||
|
||||
// negative smallest normal bf8 value to bf8 and back, check if holds
|
||||
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
|
||||
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), 0.0f);
|
||||
|
||||
// positive subnorm float value to bf8 and back, check if holds
|
||||
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
|
||||
ASSERT_NEAR(
|
||||
pos_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
|
||||
|
||||
// min subnorm bf8 value to bf8 and back, check if holds
|
||||
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
|
||||
ASSERT_NEAR(
|
||||
min_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
|
||||
|
||||
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
|
||||
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
|
||||
ASSERT_NEAR(0.0f,
|
||||
type_convert<float>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
|
||||
0.0000152587890625f);
|
||||
|
||||
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
|
||||
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
|
||||
}
|
||||
|
||||
TEST(BF8OCP, ConvertFP16Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
constexpr half_t half_t_tol = 1e-3;
|
||||
constexpr half_t half_t_zero = 0.0;
|
||||
|
||||
// convert 0 half_t to bfp8 and back, check if holds
|
||||
ASSERT_NEAR(
|
||||
half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(half_t_zero)), half_t_zero);
|
||||
|
||||
// convert minimal half_t to bf8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
|
||||
half_t_tol);
|
||||
|
||||
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
|
||||
|
||||
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
|
||||
ASSERT_NEAR(max_bf8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_half_t)),
|
||||
half_t_zero);
|
||||
|
||||
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_bf8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
|
||||
half_t_zero);
|
||||
|
||||
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(
|
||||
ck::NumericLimits<bf8_ocp_t>::Max(),
|
||||
f8_convert_rne<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
|
||||
|
||||
// positive normal bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
|
||||
ASSERT_NEAR(
|
||||
pos_norm_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
|
||||
|
||||
// negative smallest normal bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
|
||||
ASSERT_NEAR(
|
||||
neg_min_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
|
||||
|
||||
// positive subnorm bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
|
||||
ASSERT_NEAR(pos_subnorm_bf8,
|
||||
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)),
|
||||
half_t_zero);
|
||||
|
||||
// min subnorm bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t min_subnorm_bf8{-0.0000152587890625f}; //-2^-16
|
||||
ASSERT_NEAR(min_subnorm_bf8,
|
||||
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)),
|
||||
half_t_zero);
|
||||
|
||||
// smaller than min subnorm bf8 value to bf8 must be zero
|
||||
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
|
||||
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
|
||||
|
||||
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
|
||||
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
|
||||
}
|
||||
|
||||
TEST(BF8OCP, ConvertFP16Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
constexpr half_t half_t_tol = 1e-3;
|
||||
constexpr half_t half_t_zero = 0.0;
|
||||
constexpr auto min_subnorm_bf8 = 0.0000152587890625f; // 2^-16
|
||||
|
||||
// convert 0 half_t to bfp8 and back, check if holds
|
||||
ASSERT_NEAR(
|
||||
half_t_zero, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t_zero)), half_t_zero);
|
||||
|
||||
// convert minimal half_t (6.103515625e-05) to fp8 and back
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
|
||||
half_t_zero);
|
||||
|
||||
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
|
||||
|
||||
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
|
||||
ASSERT_NEAR(max_bf8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_half_t)),
|
||||
half_t_zero);
|
||||
|
||||
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_bf8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
|
||||
half_t_zero);
|
||||
|
||||
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(
|
||||
ck::NumericLimits<bf8_ocp_t>::Max(),
|
||||
f8_convert_sr<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
|
||||
|
||||
// positive normal bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
|
||||
ASSERT_NEAR(
|
||||
pos_norm_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
|
||||
|
||||
// negative smallest normal bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
|
||||
ASSERT_NEAR(
|
||||
neg_min_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
|
||||
|
||||
// positive subnorm bf8 value to bf8 and back, check if holds
|
||||
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
|
||||
ASSERT_NEAR(pos_subnorm_bf8,
|
||||
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)),
|
||||
half_t_zero);
|
||||
|
||||
// min subnorm bf8 value to bf8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{-min_subnorm_bf8},
|
||||
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t{-min_subnorm_bf8})),
|
||||
half_t_zero);
|
||||
|
||||
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
|
||||
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
|
||||
ASSERT_NEAR(half_t_zero,
|
||||
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
|
||||
half_t{min_subnorm_bf8});
|
||||
|
||||
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
|
||||
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
|
||||
}
|
||||
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
|
||||
test_vec.at(num_elem * i + 1));
|
||||
});
|
||||
}
|
||||
|
||||
#if CK_USE_OCP_FP8
|
||||
|
||||
TEST(FP8OCP, TestSize)
|
||||
{
|
||||
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
|
||||
ASSERT_EQ(sizeof(f8_t), sizeof(ck::fp8_storage_t));
|
||||
ASSERT_EQ(sizeof(vector_type<f8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<f8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<f8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<f8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<f8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<f8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
|
||||
}
|
||||
|
||||
TEST(FP8OCP, TestAsType)
|
||||
{
|
||||
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
|
||||
|
||||
// test size
|
||||
std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
|
||||
constexpr int size = test_vec.size();
|
||||
|
||||
// reference vector
|
||||
vector_type<f8_t, size> right_vec;
|
||||
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
|
||||
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
|
||||
});
|
||||
|
||||
// copy the vector
|
||||
vector_type<f8_t, size> left_vec{right_vec};
|
||||
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
|
||||
ck::type_convert<f8_t>(test_vec.at(i)));
|
||||
});
|
||||
|
||||
ck::non_native_vector_base<ck::f8_ocp_t, 2> nnvb_f8x2(ck::type_convert<f8_t>(-10.0f));
|
||||
ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<0>{}), ck::type_convert<f8_t>(-10.0f));
|
||||
ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<1>{}), ck::type_convert<f8_t>(-10.0f));
|
||||
}
|
||||
|
||||
TEST(FP8OCP, TestAsTypeReshape)
|
||||
{
|
||||
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
|
||||
|
||||
// test size
|
||||
std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1 / 256, 1, 1.5, 16};
|
||||
constexpr int size = test_vec.size();
|
||||
|
||||
// reference vector
|
||||
vector_type<f8_t, size> right_vec;
|
||||
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
|
||||
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
|
||||
});
|
||||
|
||||
// copy the first half of a vector
|
||||
vector_type<f8_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<f8_t, size / 2>::type>()(Number<0>{})};
|
||||
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
|
||||
ck::type_convert<f8_t>(test_vec.at(i)));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(BF8OCP, TestSize)
|
||||
{
|
||||
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
|
||||
ASSERT_EQ(sizeof(bf8_t), sizeof(ck::fp8_storage_t));
|
||||
ASSERT_EQ(sizeof(vector_type<bf8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<bf8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<bf8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<bf8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<bf8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<bf8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
|
||||
}
|
||||
|
||||
TEST(BF8OCP, TestAsType)
|
||||
{
|
||||
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
|
||||
|
||||
// test size
|
||||
std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
|
||||
constexpr int size = test_vec.size();
|
||||
|
||||
// reference vector
|
||||
vector_type<bf8_t, size> right_vec;
|
||||
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
|
||||
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
|
||||
});
|
||||
|
||||
// copy the vector
|
||||
vector_type<bf8_t, size> left_vec{right_vec};
|
||||
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
|
||||
ck::type_convert<bf8_t>(test_vec.at(i)));
|
||||
});
|
||||
|
||||
ck::non_native_vector_base<bf8_t, 2> nnvb_bf8x2(ck::type_convert<bf8_t>(-10.0f));
|
||||
ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<0>{}), ck::type_convert<bf8_t>(-10.0f));
|
||||
ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<1>{}), ck::type_convert<bf8_t>(-10.0f));
|
||||
}
|
||||
|
||||
TEST(BF8OCP, TestAsTypeReshape)
|
||||
{
|
||||
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
|
||||
|
||||
// test size
|
||||
std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1 / 256, 1, 1.5, 16};
|
||||
constexpr int size = test_vec.size();
|
||||
|
||||
// reference vector
|
||||
vector_type<bf8_t, size> right_vec;
|
||||
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
|
||||
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
|
||||
});
|
||||
|
||||
// copy the first half of a vector
|
||||
vector_type<bf8_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<bf8_t, size / 2>::type>()(Number<0>{})};
|
||||
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
|
||||
ck::type_convert<bf8_t>(test_vec.at(i)));
|
||||
});
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -7,154 +7,171 @@
|
||||
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::f8_t;
|
||||
using ck::f8_fnuz_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(FP8, NumericLimits)
|
||||
TEST(FP8FNUZ, NumericLimits)
|
||||
{
|
||||
// constants given for negative zero nan mode
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Min(), type_convert<f8_fnuz_t>(0x08));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Max(), type_convert<f8_fnuz_t>(0x7F));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Lowest(), type_convert<f8_fnuz_t>(0xFF));
|
||||
EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::QuietNaN(), type_convert<f8_fnuz_t>(0x80));
|
||||
}
|
||||
|
||||
TEST(FP8, ConvertFP32Nearest)
|
||||
TEST(FP8FNUZ, ConvertFP32Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol);
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_fnuz_t>(0.0f)), abs_tol);
|
||||
// don't run the next test on gfx11 devices
|
||||
#ifndef CK_SKIP_FLAKY_F8_TEST
|
||||
// convert minimal float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())),
|
||||
type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
#endif
|
||||
// convert maximal f8_t to float and check if equal to 240.0
|
||||
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol);
|
||||
// convert maximal float to fp8 and back, check if clipped to 240.0
|
||||
ASSERT_NEAR(240.0f,
|
||||
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())),
|
||||
|
||||
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
|
||||
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
|
||||
ASSERT_NEAR(
|
||||
max_f8_t_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(max_f8_t_float)), abs_tol);
|
||||
|
||||
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
|
||||
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
|
||||
// fp8 qNAN (no saturation).
|
||||
|
||||
// convert maximal float to fp8 and back, check if clipped to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_float,
|
||||
type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to f8_t and check if it is qNan
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()),
|
||||
// convert inf float to f8_fnuz_t and check if it is qNan
|
||||
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to fp8 and back, check if holds
|
||||
float pos_float = 0.017578125f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative norm float value to fp8 and back, check if holds
|
||||
float neg_float = -0.015625f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to fp8 and back, check if holds
|
||||
pos_float = 0.00390625f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to fp8 and back, check if holds
|
||||
neg_float = -0.001953125f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(FP8, ConvertFP32Stochastic)
|
||||
TEST(FP8FNUZ, ConvertFP32Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), abs_tol);
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_fnuz_t>(0.0f)), abs_tol);
|
||||
// convert minimal float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::min())),
|
||||
type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
// convert maximal f8_t to float and check if equal to 240.0
|
||||
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_sr<f8_t>(240.0f)), abs_tol);
|
||||
// convert maximal float to fp8 and back, check if clipped to 240.0
|
||||
ASSERT_NEAR(240.0f,
|
||||
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
|
||||
|
||||
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
|
||||
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
|
||||
ASSERT_NEAR(
|
||||
max_f8_t_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(max_f8_t_float)), abs_tol);
|
||||
// convert maximal float to fp8 and back, check if clipped to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_float,
|
||||
type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::max())),
|
||||
abs_tol);
|
||||
// convert inf float to f8_t and check if it is qNan
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()),
|
||||
// convert inf float to f8_fnuz_t and check if it is qNan
|
||||
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
|
||||
abs_tol);
|
||||
// positive norm float value to fp8 and back, check if holds
|
||||
float pos_float = 0.017578125f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative norm float value to fp8 and back, check if holds
|
||||
float neg_float = -0.015625f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
|
||||
// positive subnorm float value to fp8 and back, check if holds
|
||||
pos_float = 0.00390625f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
|
||||
// negative subnorm float value to fp8 and back, check if holds
|
||||
neg_float = -0.001953125f;
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(FP8, ConvertFP16Nearest)
|
||||
TEST(FP8FNUZ, ConvertFP16Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-3;
|
||||
// convert 0 fp16 to fp8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol);
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(half_t{0.0})), abs_tol);
|
||||
// convert minimal fp16 to fp8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())),
|
||||
type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
|
||||
abs_tol);
|
||||
// convert maximal f8_t to fp16 and check if equal to 240.0
|
||||
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol);
|
||||
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
|
||||
ASSERT_NEAR(half_t{240.0},
|
||||
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
|
||||
const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
|
||||
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
|
||||
ASSERT_NEAR(
|
||||
max_f8_t_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(max_f8_t_half)), abs_tol);
|
||||
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_half,
|
||||
type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
// convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to fp8 and back, check if holds
|
||||
half_t pos_half = half_t{0.017578125};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative norm fp16 value to fp8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.015625};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to fp8 and back, check if holds
|
||||
pos_half = half_t{0.00390625};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to fp8 and back, check if holds
|
||||
neg_half = half_t{-0.001953125};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(FP8, ConvertFP16Stochastic)
|
||||
TEST(FP8FNUZ, ConvertFP16Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-3;
|
||||
// convert 0 fp16 to fp8 and back, check if holds
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), abs_tol);
|
||||
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(half_t{0.0})), abs_tol);
|
||||
// convert minimal fp16 to fp8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Min())),
|
||||
type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
|
||||
abs_tol);
|
||||
// convert maximal f8_t to fp16 and check if equal to 240.0
|
||||
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{240.0})), abs_tol);
|
||||
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
|
||||
ASSERT_NEAR(half_t{240.0},
|
||||
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
|
||||
|
||||
const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
|
||||
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
|
||||
ASSERT_NEAR(
|
||||
max_f8_t_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(max_f8_t_half)), abs_tol);
|
||||
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_half,
|
||||
type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
|
||||
abs_tol);
|
||||
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(type_convert<f8_t>(0x80),
|
||||
f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
// convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
|
||||
ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
|
||||
f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
|
||||
abs_tol);
|
||||
// positive norm fp16 value to fp8 and back, check if holds
|
||||
half_t pos_half = half_t{0.017578125};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative norm fp16 value to fp8 and back, check if holds
|
||||
half_t neg_half = half_t{-0.015625};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
|
||||
// positive subnorm fp16 value to fp8 and back, check if holds
|
||||
pos_half = half_t{0.00390625};
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
|
||||
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
|
||||
// negative subnorm fp16 value to fp8 and back, check if holds
|
||||
neg_half = half_t{-0.001953125};
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
|
||||
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
|
||||
}
|
||||
250
test/data_type/test_fp8_ocp.cpp
Normal file
250
test/data_type/test_fp8_ocp.cpp
Normal file
@@ -0,0 +1,250 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::f8_ocp_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(FP8OCP, NumericLimits)
|
||||
{
|
||||
// constants given for OCP FP8
|
||||
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Min(),
|
||||
type_convert<f8_ocp_t>(0x08)); // 0b00001000 = 2^-6
|
||||
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Max(), type_convert<f8_ocp_t>(0x7E)); // 0b01111110 = 448
|
||||
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Lowest(),
|
||||
type_convert<f8_ocp_t>(0xFE)); // 0b11111110 = -448
|
||||
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::QuietNaN().data,
|
||||
type_convert<f8_ocp_t>(0x7F).data); // 0b01111111
|
||||
EXPECT_FALSE(ck::NumericLimits<f8_ocp_t>::QuietNaN() ==
|
||||
ck::NumericLimits<f8_ocp_t>::QuietNaN());
|
||||
}
|
||||
|
||||
TEST(FP8OCP, ConvertFP32Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(0.0f)), 0.0f);
|
||||
|
||||
// convert minimal float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
|
||||
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
|
||||
|
||||
// convert maximal f8_ocp_t to float and check if equal to fp8 max
|
||||
ASSERT_NEAR(
|
||||
max_f8_t_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(max_f8_t_float)), 0.0f);
|
||||
|
||||
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_f8_t_float,
|
||||
type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::max())),
|
||||
0.0f);
|
||||
|
||||
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
|
||||
f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity()));
|
||||
|
||||
// positive norm float value to fp8 and back, check if holds
|
||||
float pos_float = 0.017578125f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
|
||||
|
||||
// smallest normal fp8 value to fp8 and back, check if holds
|
||||
float neg_float = -0.015625f; //-2^-6
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
|
||||
|
||||
// positive subnorm float value to fp8 and back, check if holds
|
||||
pos_float = 0.00390625f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
|
||||
|
||||
// min subnorm fp8 value to fp8 and back, check if holds
|
||||
neg_float = -0.001953125f; //-2^-9
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
|
||||
|
||||
// smaller than min subnorm fp8 value to fp8 must be zero
|
||||
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
|
||||
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
|
||||
|
||||
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
|
||||
auto f8_nan = f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
|
||||
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, ConvertFP32Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
float abs_tol = 1e-6;
|
||||
// convert 0 float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(0.0f)), 0.0f);
|
||||
|
||||
// convert minimal float to fp8 and back, check if holds
|
||||
ASSERT_NEAR(std::numeric_limits<float>::min(),
|
||||
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::min())),
|
||||
abs_tol);
|
||||
|
||||
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
|
||||
|
||||
// convert maximal f8_ocp_t to float and check if equal to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(max_f8_t_float)), 0.0f);
|
||||
|
||||
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_f8_t_float,
|
||||
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::max())),
|
||||
0.0f);
|
||||
|
||||
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
|
||||
f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::infinity()));
|
||||
|
||||
// positive norm float value to fp8 and back, check if holds
|
||||
float pos_float = 0.017578125f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
|
||||
|
||||
// smallest normal fp8 value to fp8 and back, check if holds
|
||||
float neg_float = -0.015625f; //-2^-6
|
||||
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(neg_float)), 0.0f);
|
||||
|
||||
// positive subnorm float value to fp8 and back, check if holds
|
||||
pos_float = 0.00390625f;
|
||||
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
|
||||
|
||||
// min subnorm fp8 value to fp8 and back, check if holds
|
||||
constexpr auto min_subnorm_fp8 = -0.001953125f; //-2^-9
|
||||
ASSERT_NEAR(
|
||||
min_subnorm_fp8, type_convert<float>(f8_convert_sr<f8_ocp_t>(min_subnorm_fp8)), 0.0f);
|
||||
|
||||
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
|
||||
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
|
||||
ASSERT_NEAR(
|
||||
0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm)), 0.001953125f);
|
||||
|
||||
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
|
||||
auto f8_nan = f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
|
||||
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, ConvertFP16Nearest)
|
||||
{
|
||||
// fix the tolerance value
|
||||
constexpr half_t half_t_tol = 1e-3;
|
||||
constexpr half_t half_t_zero = 0.0;
|
||||
// convert 0 half_t to fp8 and back, check if holds
|
||||
ASSERT_NEAR(
|
||||
half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(half_t_zero)), half_t_zero);
|
||||
|
||||
// convert minimal half_t to fp8 and back, check if holds
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
|
||||
half_t_tol);
|
||||
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
|
||||
|
||||
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(max_f8_t_half_t)),
|
||||
half_t_zero);
|
||||
|
||||
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_f8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
|
||||
half_t_zero);
|
||||
|
||||
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(
|
||||
ck::NumericLimits<f8_ocp_t>::Max(),
|
||||
f8_convert_rne<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
|
||||
|
||||
// positive norm half_t value to fp8 and back, check if holds
|
||||
half_t pos_half_t{0.017578125f};
|
||||
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
|
||||
|
||||
// smallest normal fp8 value to fp8 and back, check if holds
|
||||
half_t neg_half_t{-0.015625f}; //-2^-6
|
||||
ASSERT_NEAR(
|
||||
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
|
||||
|
||||
// positive subnorm half_t value to fp8 and back, check if holds
|
||||
pos_half_t = half_t{0.00390625f};
|
||||
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
|
||||
|
||||
// min subnorm fp8 value to fp8 and back, check if holds
|
||||
neg_half_t = half_t{-0.001953125f}; //-2^-9
|
||||
ASSERT_NEAR(
|
||||
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
|
||||
|
||||
// smaller than min subnorm fp8 value to fp8 must be zero
|
||||
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
|
||||
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
|
||||
|
||||
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
|
||||
auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
|
||||
}
|
||||
|
||||
TEST(FP8OCP, ConvertFP16Stochastic)
|
||||
{
|
||||
// fix the tolerance value
|
||||
constexpr half_t half_t_tol = 1e-3;
|
||||
constexpr half_t half_t_zero = 0.0;
|
||||
constexpr auto min_subnorm_fp8 = 0.001953125f; // 2^-9
|
||||
|
||||
// convert 0 half_t to fp8 and back, check if holds
|
||||
ASSERT_NEAR(
|
||||
half_t_zero, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(half_t_zero)), half_t_zero);
|
||||
|
||||
// convert minimal half_t (6.103515625e-05) to fp8 and back
|
||||
// alternates between 0 and 2^-9 (0.001953125)
|
||||
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
|
||||
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
|
||||
type_convert<half_t>(min_subnorm_fp8));
|
||||
|
||||
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
|
||||
|
||||
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
|
||||
ASSERT_NEAR(max_f8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(max_f8_t_half_t)),
|
||||
half_t_zero);
|
||||
|
||||
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
|
||||
ASSERT_NEAR(max_f8_t_half_t,
|
||||
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
|
||||
half_t_zero);
|
||||
|
||||
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
|
||||
ASSERT_EQ(
|
||||
ck::NumericLimits<f8_ocp_t>::Max(),
|
||||
f8_convert_sr<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
|
||||
|
||||
// positive norm half_t value to fp8 and back, check if holds
|
||||
half_t pos_half_t{0.017578125f};
|
||||
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
|
||||
|
||||
// smallest normal fp8 value to fp8 and back, check if holds
|
||||
half_t neg_half_t{-0.015625f}; //-2^-6
|
||||
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
|
||||
|
||||
// positive subnorm half_t value to fp8 and back, check if holds
|
||||
pos_half_t = half_t{0.00390625f};
|
||||
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
|
||||
|
||||
// min subnorm fp8 value to fp8 and back, check if holds
|
||||
neg_half_t = half_t{-min_subnorm_fp8}; //-2^-9
|
||||
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
|
||||
|
||||
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
|
||||
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
|
||||
ASSERT_NEAR(
|
||||
type_convert<float>(half_t_zero),
|
||||
type_convert<float>(type_convert<half_t>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm))),
|
||||
min_subnorm_fp8);
|
||||
|
||||
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
|
||||
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
|
||||
}
|
||||
@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
|
||||
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
|
||||
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
|
||||
|
||||
TYPED_TEST(AvgPool2D_F32, AvgPool2D_I8_Test) { this->Run(); }
|
||||
TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
|
||||
TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
|
||||
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
|
||||
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
|
||||
|
||||
@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
|
||||
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
|
||||
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
|
||||
|
||||
TYPED_TEST(MaxPool2D_F32, MaxPool2D_I8_Test) { this->Run(); }
|
||||
TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
|
||||
TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
|
||||
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
|
||||
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
|
||||
|
||||
Reference in New Issue
Block a user