diff --git a/CMakeLists.txt b/CMakeLists.txt
index ef46d96f4d..3e1174ec04 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,6 +21,15 @@ rocm_setup_version(VERSION 0.2.0)
 
 include(TargetFlags)
 list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
+# Opt-in switch: defines CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 so that ck::int4_t is available.
+option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
+
+if(USE_BITINT_EXTENSION_INT4)
+    add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
+    add_compile_options(-Wno-bit-int-extension)
+    message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
+endif()
+
 ## C++
 enable_language(CXX)
 set(CMAKE_CXX_STANDARD 17)
diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake
index cf2240ebc5..3c6cb56cce 100644
--- a/cmake/googletest.cmake
+++ b/cmake/googletest.cmake
@@ -42,3 +42,9 @@ target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
 target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
 target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
 target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
+
+# Build all googletest/googlemock targets as position-independent code.
+set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE ON)
+set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE ON)
+set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE ON)
+set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
index 97e5d38feb..7595b4402a 100644
--- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
+++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -62,6 +62,15 @@ struct PassThrough
     {
         y = type_convert(x);
     }
+
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+    // int4_t -> int4_t pass-through: a plain copy, no conversion involved.
+    template <>
+    __host__ __device__ void operator()(int4_t& y, const int4_t& x) const
+    {
+        y = x;
+    }
+#endif
 };
 
 struct UnaryConvert
@@ -111,9 +120,13 @@ struct UnarySquare
     template
     __host__ __device__ void operator()(T& y, const T& x) const
     {
-        static_assert(is_same::value || is_same::value,
+        static_assert(is_same_v || is_same_v || is_same_v ||
+                          is_same_v
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+                          || is_same_v
+#endif
+                      ,
                       "Data type is not supported by this operation!");
-
         y = x * x;
     };
 };
diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp
index 4b578bf149..24bb13d7fb 100644
--- a/include/ck/utility/data_type.hpp
+++ b/include/ck/utility/data_type.hpp
@@ -9,6 +9,10 @@ namespace ck {
 
 using bhalf_t = ushort;
 using half_t = _Float16;
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// 4-bit signed integer built on clang's _BitInt extension (experimental, opt-in).
+using int4_t = _BitInt(4);
+#endif
 
 // vector_type
 template
@@ -130,6 +134,16 @@ struct scalar_type
     static constexpr index_t vector_size = 1;
 };
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// int4_t is its own scalar type with a vector size of 1.
+template <>
+struct scalar_type<int4_t>
+{
+    using type = int4_t;
+    static constexpr index_t vector_size = 1;
+};
+#endif
+
 //
 template
 struct vector_type
@@ -1030,4 +1044,17 @@ struct NumericLimits
     __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); }
 };
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// Limits of int4_t: a signed two's complement _BitInt(4) represents the range [-8, 7].
+template <>
+struct NumericLimits<int4_t>
+{
+    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
+
+    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
+
+    __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
+};
+#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+
 } // namespace ck
diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp
index fc264117f0..84a057815f 100644
--- a/include/ck/utility/math_v2.hpp
+++ b/include/ck/utility/math_v2.hpp
@@ -42,6 +42,16 @@ static inline __host__ half_t abs(half_t x)
     return abs_x;
 };
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// Branch-free absolute value. With an arithmetic right shift, sgn is -1 for negative x and 0
+// otherwise, so (x ^ sgn) - sgn negates x only when it is negative.
+static inline __host__ int4_t abs(int4_t x)
+{
+    int4_t sgn = x >> (4 - 1);
+    return (x ^ sgn) - sgn;
+}
+#endif
+
 static inline __host__ bool isnan(float x) { return std::isnan(x); };
 
 static inline __host__ bool isnan(double x) { return std::isnan(x); };
@@ -65,6 +75,15 @@ static inline __host__ bool isnan(half_t x)
     return (xx & 0x7FFF) > 0x7C00;
 };
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// int4_t is an integer type, so no value of it is NaN.
+static inline __host__ bool isnan(int4_t x)
+{
+    (void)x;
+    return false;
+};
+#endif
+
 static inline __host__ float sqrt(float x) { return std::sqrt(x); };
 
 static inline __host__ double sqrt(double x) { return std::sqrt(x); };
@@ -89,6 +108,16 @@ static inline __device__ int32_t abs(int32_t x)
     return (x ^ sgn) - sgn;
 };
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// Device-side int4_t abs: same sign-mask expression as the int32_t overload above.
+static inline __device__ int4_t abs(int4_t x)
+{
+    int4_t sgn = x >> (4 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+#endif
+
 static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
 
 static inline __device__ bool isnan(float x) { return ::isnan(x); };
@@ -107,6 +136,15 @@ static inline __device__ bool isnan(int32_t x)
     return false;
 };
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// int4_t is an integer type, so no value of it is NaN.
+static inline __device__ bool isnan(int4_t x)
+{
+    (void)x;
+    return false;
+};
+#endif
+
 static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
 
 static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index f391e478c4..50cb730f69 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -51,3 +51,4 @@ add_subdirectory(grouped_convnd_fwd)
 add_subdirectory(block_to_ctile_map)
 add_subdirectory(softmax)
 add_subdirectory(layernorm)
+add_subdirectory(data_type)
diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt
new file mode 100644
index 0000000000..088fbfec71
--- /dev/null
+++ b/test/data_type/CMakeLists.txt
@@ -0,0 +1,5 @@
+# The int4 test is only built when USE_BITINT_EXTENSION_INT4 is enabled.
+if (USE_BITINT_EXTENSION_INT4)
+    add_gtest_executable(test_int4 int4.cpp)
+    target_link_libraries(test_int4 PRIVATE utility)
+endif()
diff --git a/test/data_type/int4.cpp b/test/data_type/int4.cpp
new file mode 100644
index 0000000000..9d9cc294ca
--- /dev/null
+++ b/test/data_type/int4.cpp
@@ -0,0 +1,45 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gtest/gtest.h"
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/math_v2.hpp"
+
+using ck::int4_t;
+
+TEST(Int4, BaseArithmetic)
+{
+    int4_t a{1};
+    int4_t b{-2};
+    EXPECT_EQ(a + a, int4_t{2});
+    EXPECT_EQ(a - a, int4_t{0});
+    EXPECT_EQ(a + b, int4_t{-1});
+    EXPECT_EQ(a - b, int4_t{3});
+    EXPECT_EQ(a * a, int4_t{1});
+    EXPECT_EQ(a * b, int4_t{-2});
+    EXPECT_EQ(b * b, int4_t{4});
+    EXPECT_EQ(a / b, int4_t{0});
+    a = int4_t{4};
+    EXPECT_EQ(a / b, int4_t{-2});
+    b = int4_t{2};
+    EXPECT_EQ(a % b, int4_t{0});
+}
+
+// A signed two's complement _BitInt(4) covers [-8, 7].
+TEST(Int4, NumericLimits)
+{
+    EXPECT_EQ(ck::NumericLimits<int4_t>::Min(), int4_t{-8});
+    EXPECT_EQ(ck::NumericLimits<int4_t>::Max(), int4_t{7});
+    EXPECT_EQ(ck::NumericLimits<int4_t>::Lowest(), int4_t{-8});
+}
+
+TEST(Int4, MathOpsV2)
+{
+    int4_t a{4};
+    int4_t b{-5};
+
+    EXPECT_EQ(ck::math::abs(a), int4_t{4});
+    EXPECT_EQ(ck::math::abs(b), int4_t{5});
+    EXPECT_FALSE(ck::math::isnan(b));
+}