diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 245fb7244f..e709fed23d 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -117,12 +117,8 @@ using bf16_raw_t = uint16_t; CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_rtn_raw(float f) { - union - { - float fp32; - uint32_t int32; - } u = {f}; - if(~u.int32 & 0x7f800000) + uint32_t bits = bit_cast(f); + if(~bits & 0x7f800000) { // When the exponent bits are not all 1s, then the value is zero, normal, // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus @@ -140,9 +136,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f) // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, // incrementing it causes it to become an exponent of 0xFF and a mantissa // of 0x00, which is Inf, the next higher value to the unrounded value. - u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even } - else if(u.int32 & 0xffff) + else if(bits & 0xffff) { // When all of the exponent bits are 1, the value is Inf or NaN. // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero @@ -152,9 +148,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f) // lower 16 bits of the mantissa are 1, we set the least significant bit // of the bfloat16 mantissa, in order to preserve signaling NaN in case // the bloat16's mantissa bits are all 0. - u.int32 |= 0x10000; // Preserve signaling NaN + bits |= 0x10000; // Preserve signaling NaN } - return uint16_t(u.int32 >> 16); + return uint16_t(bits >> 16); } CK_TILE_HOST @@ -225,24 +221,16 @@ uint16_t float_to_bf16_rta_asm(float f) CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f) { - union - { - float fp32; - uint32_t int32; - } u = {f}; - return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); + uint32_t bits = bit_cast(f); + return static_cast(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff)); } // Fast truncate instead of rounding, RTZ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_raw(float f) { - union - { - float fp32; - uint32_t int32; - } u = {f}; - return uint16_t(u.int32 >> 16); + uint32_t bits = bit_cast(f); + return static_cast(bits >> 16); } template @@ -287,7 +275,7 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) { -#if defined(__gfx950__) +#if CK_TILE_USE_LLVM_BUILTIN_BF16 return static_cast(f); #else return bit_cast(float_to_bf16_raw(f, constant{})); diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 221592ee10..ea8ba4557e 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -7,9 +7,26 @@ #include #include +#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1 +#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0 +#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0 + namespace ck_tile { namespace element_wise { +// Generalized constexpr lookup table generator +template +constexpr std::array make_lookup_table_impl(F&& func, std::index_sequence) +{ + return {func(Is)...}; +} + +template +constexpr std::array make_lookup_table(F&& func) +{ + return make_lookup_table_impl(std::forward(func), std::make_index_sequence{}); +} + /** * @brief Fast int4x4 to fp16x8_t data type conversion based on paper * "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production" @@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale) */ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) { +#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16 + // This approach fails validation in GEMM tests. uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); static constexpr uint32_t fp32_base = 0x4B000000; @@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); return res; +#else + // Lookup table for bf16_t values corresponding to int4 values -8 to 7 + constexpr auto bf16_lookup_table = make_lookup_table( + [](int i) { return bit_cast(float_to_bf16_rtn_raw(i - 8)); }); + + return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf], + bf16_lookup_table[(q >> 16) & 0xf], + bf16_lookup_table[(q >> 4) & 0xf], + bf16_lookup_table[(q >> 20) & 0xf]}; +#endif } +#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8 /** * @brief This function converts 8 packed 4-bit integers into 8 fp8 values. * @@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) return bit_cast((static_cast(tmp_res_high) << 32) | tmp_res_low); } +#else +CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q) +{ + // The approach below can be used once this compiler issue is resolved: + // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported" + // Lookup table for fp8_t values corresponding to int4 values -8 to 7 + constexpr auto fp8_lookup_table = make_lookup_table( + [](int i) { return impl::cast_to_f8(i - 8, 0); }); + + return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf], + fp8_lookup_table[(q >> 16) & 0xf], + fp8_lookup_table[(q >> 4) & 0xf], + fp8_lookup_table[(q >> 20) & 0xf]}; +} +#endif CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src) { @@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) return res; } +#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8 /** * @brief This function converts 8 packed 4-bit integers into 8 bf8 values. * @@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) return bit_cast((static_cast(tmp_res_high) << 32) | tmp_res_low); } +#else +CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) +{ + // The approach below can be used once this compiler issue is resolved: + // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported" + // Lookup table for bf8_t values corresponding to int4 values -8 to 7 + constexpr auto bf8_lookup_table = make_lookup_table( + [](int i) { return impl::cast_to_f8(i - 8, 0); }); + + return bf8x4_t{bf8_lookup_table[(q >> 0) & 0xf], + bf8_lookup_table[(q >> 16) & 0xf], + bf8_lookup_table[(q >> 4) & 0xf], + bf8_lookup_table[(q >> 20) & 0xf]}; +} +#endif struct PassThroughPack8 { @@ -300,17 +361,27 @@ struct PassThroughPack8 CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const { y.lo = i4_to_bhalf4(bit_cast(x)); - y.hi = i4_to_bhalf4(bit_cast(x) >> 16); + y.hi = i4_to_bhalf4(bit_cast(x) >> 8); } CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const { +#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8 y = amd_assembly_i4_to_fp8x8(bit_cast(x)); +#else + y.lo = i4_to_fp8x4(bit_cast(x)); + y.hi = i4_to_fp8x4(bit_cast(x) >> 8); +#endif } CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const { +#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8 y = amd_assembly_i4_to_bf8x8(bit_cast(x)); +#else + y.lo = i4_to_bf8x4(bit_cast(x)); + y.hi = i4_to_bf8x4(bit_cast(x) >> 8); +#endif } constexpr const static bool is_pack8_invocable = true; }; diff --git a/test/ck_tile/elementwise/CMakeLists.txt b/test/ck_tile/elementwise/CMakeLists.txt index 5fca0eb801..860a23a62a 100644 --- a/test/ck_tile/elementwise/CMakeLists.txt +++ b/test/ck_tile/elementwise/CMakeLists.txt @@ -1,6 +1,3 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) - if(result EQUAL 0) - target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility) - endif() -endif() \ No newline at end of file +endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp index 4e3033782c..23548f2f92 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp @@ -2,4 +2,11 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp index 61614fc6f5..cbf25a223a 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp @@ -2,4 +2,12 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = + run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp index c667c08053..7afeb4140d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp @@ -2,4 +2,13 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = run_gemm_combinations() && is_success; +#if 0 + is_success = + run_gemm_combinations() && is_success; +#endif + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp index 9a3498b7ea..0ba4b54403 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp @@ -2,4 +2,12 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "test_gemm_pipeline_basic_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = + run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 706035cabc..2c8a776f10 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -225,7 +225,7 @@ bool run_gemm_test(int argc, char* argv[]) } template -int run_gemm_combinations() +bool run_gemm_combinations() { // Define possible values for each parameter std::vector m_values = {"128", "1024"}; @@ -304,5 +304,5 @@ int run_gemm_combinations() } } } - return is_success ? EXIT_SUCCESS : EXIT_FAILURE; + return is_success; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index 52f6ea7026..cfcf3cb08c 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -263,6 +263,15 @@ struct GemmTypeConfig using CDataType = ck_tile::bf16_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + template <> struct GemmTypeConfig { @@ -281,6 +290,15 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template <> struct GemmTypeConfig { @@ -290,6 +308,15 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template <> struct GemmTypeConfig { diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp index 1336f6fd70..cf8cbd69c5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp @@ -6,4 +6,11 @@ #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp index 5d55f34b84..90f539f176 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp @@ -6,4 +6,12 @@ #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = + run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp index 0cebbcc721..727d43282a 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp @@ -6,4 +6,11 @@ #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp index 29fb5f87ce..8fbbec8e9f 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp @@ -6,4 +6,12 @@ #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = + run_gemm_combinations() && is_success; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp index e8a089d8ff..991f84788f 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_int8.cpp @@ -1,16 +1,15 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp index 043db10fb0..8abf05dbcf 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_pk_int4.cpp @@ -1,16 +1,15 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include - -#include -#include -#include - #include "ck_tile/host.hpp" #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" #include "test_gemm_pipeline_universal_run_test.inc" -int main() { return run_gemm_combinations(); } +int main() +{ + bool is_success = true; + is_success = + run_gemm_combinations() && is_success; + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index dfee45cdfd..d566f4eacb 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -357,5 +357,5 @@ int run_gemm_combinations() } } } - return is_success ? EXIT_SUCCESS : EXIT_FAILURE; + return is_success; }