diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index fa49f6ddd5..66f094557c 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -163,6 +163,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // set rounding to nearest even as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 0 +// shuffle pk_i4 values during conversion to optimize number of binary +// operations +#define CK_USE_PK4_LAYOUT_SHUFFLE 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 139f0057e4..f1055d1eff 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -16,7 +16,8 @@ namespace ck { // [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] // (https://arxiv.org/abs/2211.10017) and implementation: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__host__ __device__ inline half4_t pki4_to_half4(int q) +// Convert lower part of packed int4 -> int4 to half +__device__ inline half4_t i4_to_half4(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; @@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q) return res.template AsType()[Number<0>{}]; } -__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale) +__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) { const int LO = 0x000f000f; const int HI = 0x00f000f0; @@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& return res.template AsType()[Number<0>{}]; } -__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) -{ 
-#if 1 - uint8_t x_u8 = ck::bit_cast(q); - uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); - - const int EX = 0x64006400; - const int SUB = 0xE408E408; //-8 - - int lo = i4s | EX; - - return amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); -#else - uint8_t x_u8 = ck::bit_cast(q); - - vector_type res; - - half_t x_h = (x_u8 & 0x0f) - 8; - half_t x_l = ((x_u8 & 0xf0) >> 4) - 8; - - res.template AsType()(Number<0>{}) = x_l; - res.template AsType()(Number<1>{}) = x_h; - - return res.template AsType()[Number<0>{}]; -#endif -} - -__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) +__device__ inline bhalf4_t i4_to_bhalf4(int q) { uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); @@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) return res.template AsType()[Number<0>{}]; } -__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q) -{ - uint8_t x_u8 = ck::bit_cast(q); - - float x_h = ((x_u8 & 0x0f) >> 0) - 8.f; - float x_l = ((x_u8 & 0xf0) >> 4) - 8.f; - - vector_type res; - - res.template AsType()(Number<0>{}) = type_convert(x_l); - res.template AsType()(Number<1>{}) = type_convert(x_h); - - return res.template AsType()[Number<0>{}]; -} - namespace tensor_operation { namespace element_wise { @@ -159,11 +118,11 @@ struct PassThroughPack8 __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE vector_type result; - result.template AsType()(Number<0>{}) = pki4_to_half4(bit_cast(x)); - result.template AsType()(Number<1>{}) = pki4_to_half4(bit_cast(x) >> 8); + result.template AsType()(Number<0>{}) = i4_to_half4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_half4(bit_cast(x) >> 8); y = result.template AsType()[Number<0>{}]; #else @@ -171,13 +130,13 @@ struct PassThroughPack8 vector_type src{x}; dst.template AsType()(Number<0>{}) = - pki4_to_half2(src.template AsType()[Number<0>{}]); + 
type_convert(src.template AsType()[Number<0>{}]); dst.template AsType()(Number<1>{}) = - pki4_to_half2(src.template AsType()[Number<1>{}]); + type_convert(src.template AsType()[Number<1>{}]); dst.template AsType()(Number<2>{}) = - pki4_to_half2(src.template AsType()[Number<2>{}]); + type_convert(src.template AsType()[Number<2>{}]); dst.template AsType()(Number<3>{}) = - pki4_to_half2(src.template AsType()[Number<3>{}]); + type_convert(src.template AsType()[Number<3>{}]); y = dst.template AsType()[Number<0>{}]; #endif @@ -185,11 +144,11 @@ struct PassThroughPack8 __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE vector_type result; - result.template AsType()(Number<0>{}) = pki4_to_bhalf4(bit_cast(x)); - result.template AsType()(Number<1>{}) = pki4_to_bhalf4(bit_cast(x) >> 16); + result.template AsType()(Number<0>{}) = i4_to_bhalf4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_bhalf4(bit_cast(x) >> 16); y = result.template AsType()[Number<0>{}]; #else @@ -197,13 +156,13 @@ struct PassThroughPack8 vector_type src{x}; dst.template AsType()(Number<0>{}) = - pki4_to_bhalf2(src.template AsType()[Number<0>{}]); + type_convert(src.template AsType()[Number<0>{}]); dst.template AsType()(Number<1>{}) = - pki4_to_bhalf2(src.template AsType()[Number<1>{}]); + type_convert(src.template AsType()[Number<1>{}]); dst.template AsType()(Number<2>{}) = - pki4_to_bhalf2(src.template AsType()[Number<2>{}]); + type_convert(src.template AsType()[Number<2>{}]); dst.template AsType()(Number<3>{}) = - pki4_to_bhalf2(src.template AsType()[Number<3>{}]); + type_convert(src.template AsType()[Number<3>{}]); y = dst.template AsType()[Number<0>{}]; #endif @@ -219,12 +178,12 @@ struct DequantPack8 __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE vector_type result; - result.template 
AsType()(Number<0>{}) = pki4_to_half4_scale(bit_cast(x), z); + result.template AsType()(Number<0>{}) = i4_to_half4_scale(bit_cast(x), z); result.template AsType()(Number<1>{}) = - pki4_to_half4_scale(bit_cast(x) >> 8, z); + i4_to_half4_scale(bit_cast(x) >> 8, z); y = result.template AsType()[Number<0>{}]; #else @@ -232,13 +191,13 @@ struct DequantPack8 vector_type src{x}; dst.template AsType()(Number<0>{}) = - pki4_to_half2(src.template AsType()[Number<0>{}]); + type_convert(src.template AsType()[Number<0>{}]); dst.template AsType()(Number<1>{}) = - pki4_to_half2(src.template AsType()[Number<1>{}]); + type_convert(src.template AsType()[Number<1>{}]); dst.template AsType()(Number<2>{}) = - pki4_to_half2(src.template AsType()[Number<2>{}]); + type_convert(src.template AsType()[Number<2>{}]); dst.template AsType()(Number<3>{}) = - pki4_to_half2(src.template AsType()[Number<3>{}]); + type_convert(src.template AsType()[Number<3>{}]); y = dst.template AsType()[Number<0>{}]; #endif @@ -260,7 +219,7 @@ struct PassThroughPack2 __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const { -#if 1 +#if CK_USE_PK4_LAYOUT_SHUFFLE uint8_t x_u8 = ck::bit_cast(x); uint8_t x_l = (x_u8 & 0x0f) >> 0; uint8_t x_h = (x_u8 & 0xf0) >> 4; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 74187bfee9..a86de19645 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -7,6 +7,8 @@ #include "ck/utility/f8_utils.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/array.hpp" +#include "ck/utility/amd_inline_asm.hpp" +#include "ck/utility/type.hpp" namespace ck { // Define the common macro for MI300 models @@ -14,6 +16,26 @@ namespace ck { #define __gfx94__ #endif +namespace { +namespace details { + +[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y) +{ + half2_t vector_res; + + vector_res.x = x.x + y.x; + vector_res.y = x.y + y.y; + + return 
vector_res; +} + +[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y) +{ + return amd_assembly_pk_add_f16(x, y); +} +} // namespace details +} // namespace + // Declare a template function for bf16 conversion using RTN template __host__ __device__ constexpr Y bf16_convert_rtn(X x); @@ -520,13 +542,51 @@ template <> inline __host__ __device__ float2_t type_convert(pk_i4_t x) { uint8_t x_u8 = ck::bit_cast(x); - uint8_t x_l = (x_u8 & 0x0f) >> 0; - uint8_t x_h = (x_u8 & 0xf0) >> 4; - auto l_f32 = ck::type_convert(x_l); - auto h_f32 = ck::type_convert(x_h); + float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; + float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; - return {l_f32, h_f32}; +#if CK_USE_PK4_LAYOUT_SHUFFLE + float2_t res = {x_h, x_l}; +#else + float2_t res = {x_l, x_h}; +#endif + return res; +} + +template <> +inline __host__ __device__ half2_t type_convert(pk_i4_t x) +{ + uint8_t x_u8 = ck::bit_cast(x); +#if CK_USE_PK4_LAYOUT_SHUFFLE + uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); +#else + uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf); +#endif + + const int EX = 0x64006400; + const int SUB = 0xE408E408; //-8 + + int lo = i4s | EX; + + return details::pk_add_f16(bit_cast(lo), bit_cast(SUB)); +} + +template <> +inline __host__ __device__ bhalf2_t type_convert(pk_i4_t x) +{ + uint8_t x_u8 = ck::bit_cast(x); + + float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; + float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; + +#if CK_USE_PK4_LAYOUT_SHUFFLE + bhalf2_t res = {type_convert(x_h), type_convert(x_l)}; +#else + bhalf2_t res = {type_convert(x_l), type_convert(x_h)}; +#endif + + return res; } template <> diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 5610c093ca..ba4f4b6e7d 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,6 +27,7 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" 
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 4c495ba781..c761fcb8c3 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -144,6 +144,10 @@ #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 #endif +#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE +#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1 +#endif + // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp new file mode 100644 index 0000000000..2ffcc36ced --- /dev/null +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/random.hpp" +#include +#include +#include "ck_tile/core/numeric/int8.hpp" + +#pragma once + +namespace ck_tile { + +// Packed 2xint4 +struct pk_int4_t +{ + using type = int8_t; + type data; + __host__ __device__ constexpr pk_int4_t() : data{type{}} {} + __host__ __device__ constexpr pk_int4_t(type init) : data{init} {} +}; + +// limits +template +struct numeric; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr pk_int4_t min() + { + constexpr uint8_t val = 0b10001000; + return pk_int4_t(bit_cast(val)); + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest() + { + constexpr uint8_t val = 
0b10001000; + return pk_int4_t(bit_cast(val)); + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr pk_int4_t max() + { + constexpr uint8_t val = 0b01110111; + return pk_int4_t(bit_cast(val)); + } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon() + { + return 1; // not used + } + + CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error() + { + return 1; // not used + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity() + { + return 1; // not used + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN() + { + return 1; // not used + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN() + { + return 1; // not used + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min() + { + return 1; // not used + } + + CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; } +}; + +CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); + + float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; + float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; + +#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE + fp32x2_t res = {x_h, x_l}; +#else + fp32x2_t res = {x_l, x_h}; +#endif + return res; +} + +CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); +#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE + uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); +#else + uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf); +#endif + const int EX = 0x64006400; + const int SUB = 0xE408E408; //-8 + + int lo = i4s | EX; + + return pk_add_f16(bit_cast(lo), bit_cast(SUB)); +} + +CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); + + float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; + float x_h = ((x_u8 & 0xf0) >> 4) - 
8.f; + +#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE + bf16x2_t res = {type_convert(x_h), type_convert(x_l)}; +#else + bf16x2_t res = {type_convert(x_l), type_convert(x_h)}; +#endif + return res; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 3ef066a3eb..9aeb494919 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32))); using bf8x64_t = bf8_t __attribute((ext_vector_type(64))); #endif +__host__ inline fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) +{ + fp16x2_t vector_res; + + vector_res.x = x.x + y.x; + vector_res.y = x.y + y.y; + + return vector_res; +} + +__device__ inline fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) +{ + fp16x2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y)); + return c; +} + } // namespace ck_tile diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 77cf35f667..8f9d7ac89b 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(image_to_column) add_subdirectory(gemm) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) +add_subdirectory(data_type) diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt new file mode 100644 index 0000000000..e489f306f7 --- /dev/null +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -0,0 +1,4 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp) +endif() diff --git a/test/ck_tile/data_type/test_pk_int4.cpp b/test/ck_tile/data_type/test_pk_int4.cpp 
new file mode 100644 index 0000000000..4e9fb20efc --- /dev/null +++ b/test/ck_tile/data_type/test_pk_int4.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include + +#include "ck_tile/core.hpp" + +using ck_tile::bf16_t; +using ck_tile::bf16x2_t; +using ck_tile::fp16x2_t; +using ck_tile::fp32x2_t; +using ck_tile::half_t; +using ck_tile::pk_int4_t; + +TEST(PackedInt4, ConvertToFloat) +{ +#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE + constexpr float first_input_val = 7.f; + constexpr float second_input_val = -1.f; +#else + constexpr float first_input_val = -1.f; + constexpr float second_input_val = 7.f; +#endif + uint8_t data = 0b11110111; // {-1, 7} + pk_int4_t in = ck_tile::bit_cast(data); + fp32x2_t out = ck_tile::pk_int4_t_to_fp32x2_t(in); + + EXPECT_EQ(out.x, first_input_val); + EXPECT_EQ(out.y, second_input_val); +} + +TEST(PackedInt4, ConvertToHalf) +{ +#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE + const half_t first_input_val = ck_tile::type_convert(7.f); + const half_t second_input_val = ck_tile::type_convert(-1.f); +#else + const half_t first_input_val = ck_tile::type_convert(-1.f); + const half_t second_input_val = ck_tile::type_convert(7.f); +#endif + uint8_t data = 0b11110111; // {-1, 7} + pk_int4_t in = ck_tile::bit_cast(data); + fp16x2_t out = ck_tile::pk_int4_t_to_halfx2_t(in); + + EXPECT_EQ(out.x, first_input_val); + EXPECT_EQ(out.y, second_input_val); +} + +TEST(PackedInt4, ConvertToBHalf) +{ +#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE + const bf16_t first_input_val = ck_tile::type_convert(7.f); + const bf16_t second_input_val = ck_tile::type_convert(-1.f); +#else + const bf16_t first_input_val = ck_tile::type_convert(-1.f); + const bf16_t second_input_val = ck_tile::type_convert(7.f); +#endif + uint8_t data = 0b11110111; // {-1, 7} + pk_int4_t in = ck_tile::bit_cast(data); + bf16x2_t out = ck_tile::pk_int4_t_to_bfloat16x2_t(in); + + EXPECT_EQ(out.x, 
first_input_val); + EXPECT_EQ(out.y, second_input_val); +} diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index a0ba3ed974..3b1dfecb48 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -50,3 +50,4 @@ endif() add_gtest_executable(test_type_convert_const type_convert_const.cpp) add_gtest_executable(test_bhalf test_bhalf.cpp) +add_gtest_executable(test_pk_i4 test_pk_i4.cpp) diff --git a/test/data_type/test_pk_i4.cpp b/test/data_type/test_pk_i4.cpp new file mode 100644 index 0000000000..d8d4d0e36d --- /dev/null +++ b/test/data_type/test_pk_i4.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include "gtest/gtest.h" +#include + +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/library/utility/device_memory.hpp" + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +using ck::bhalf2_t; +using ck::bhalf_t; +using ck::float2_t; +using ck::half2_t; +using ck::half4_t; +using ck::half_t; +using ck::pk_i4_t; +using ck::pk_i4x4_t; + +TEST(PackedInt4, ConvertToFloat) +{ +#if CK_USE_PK4_LAYOUT_SHUFFLE + constexpr float first_input_val = 7.f; + constexpr float second_input_val = -1.f; +#else + constexpr float first_input_val = -1.f; + constexpr float second_input_val = 7.f; +#endif + uint8_t data = 0b11110111; // {-1, 7} + pk_i4_t in = ck::bit_cast(data); + float2_t out = ck::type_convert(in); + + EXPECT_EQ(out.x, first_input_val); + EXPECT_EQ(out.y, second_input_val); +} + +TEST(PackedInt4, ConvertToHalf) +{ +#if CK_USE_PK4_LAYOUT_SHUFFLE + constexpr half_t first_input_val = ck::type_convert(7.f); + constexpr half_t second_input_val = ck::type_convert(-1.f); +#else + constexpr half_t first_input_val = ck::type_convert(-1.f); + constexpr half_t 
second_input_val = ck::type_convert(7.f); +#endif + uint8_t data = 0b11110111; // {-1, 7} + pk_i4_t in = ck::bit_cast(data); + half2_t out = ck::type_convert(in); + + EXPECT_EQ(out.x, first_input_val); + EXPECT_EQ(out.y, second_input_val); +} + +TEST(PackedInt4, ConvertToBHalf) +{ +#if CK_USE_PK4_LAYOUT_SHUFFLE + const bhalf_t first_input_val = ck::type_convert(7.f); + const bhalf_t second_input_val = ck::type_convert(-1.f); +#else + const bhalf_t first_input_val = ck::type_convert(-1.f); + const bhalf_t second_input_val = ck::type_convert(7.f); +#endif + uint8_t data = 0b11110111; // {-1, 7} + pk_i4_t in = ck::bit_cast(data); + bhalf2_t out = ck::type_convert(in); + + EXPECT_EQ(out.x, first_input_val); + EXPECT_EQ(out.y, second_input_val); +}