mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
int4 data type (#364)
* Introduce int4 data type.
* Add unit-tests for int4
* Compile int4 UT only when int4 enabled.
* clang-format
Co-authored-by: Adam Osewski <aosewski@amd.com>
[ROCm/composable_kernel commit: e00149ac67]
This commit is contained in:
@@ -21,6 +21,14 @@ rocm_setup_version(VERSION 0.2.0)
|
||||
include(TargetFlags)
|
||||
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
|
||||
|
||||
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
add_compile_options(-Wno-bit-int-extension)
|
||||
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
endif()
|
||||
|
||||
## C++
|
||||
enable_language(CXX)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
@@ -42,3 +42,8 @@ target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
|
||||
set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
@@ -62,6 +62,14 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -111,9 +119,13 @@ struct UnarySquare
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value,
|
||||
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
|
||||
is_same_v<T, int8_t>
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| is_same_v<T, int4_t>
|
||||
#endif
|
||||
,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x * x;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -9,6 +9,9 @@ namespace ck {
|
||||
|
||||
using bhalf_t = ushort;
|
||||
using half_t = _Float16;
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
using int4_t = _BitInt(4);
|
||||
#endif
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
@@ -130,6 +133,15 @@ struct scalar_type<int8_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct scalar_type<int4_t>
|
||||
{
|
||||
using type = int4_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
#endif
|
||||
|
||||
//
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
@@ -1030,4 +1042,16 @@ struct NumericLimits<half_t>
|
||||
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct NumericLimits<int4_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int4_t Min() { return int4_t(-7); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-7); }
|
||||
};
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -42,6 +42,14 @@ static inline __host__ half_t abs(half_t x)
|
||||
return abs_x;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
static inline __host__ int4_t abs(int4_t x)
|
||||
{
|
||||
int4_t sgn = x >> (4 - 1);
|
||||
return (x ^ sgn) - sgn;
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline __host__ bool isnan(float x) { return std::isnan(x); };
|
||||
|
||||
static inline __host__ bool isnan(double x) { return std::isnan(x); };
|
||||
@@ -65,6 +73,14 @@ static inline __host__ bool isnan(half_t x)
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
static inline __host__ bool isnan(int4_t x)
|
||||
{
|
||||
(void)x;
|
||||
return false;
|
||||
};
|
||||
#endif
|
||||
|
||||
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
|
||||
@@ -89,6 +105,15 @@ static inline __device__ int32_t abs(int32_t x)
|
||||
return (x ^ sgn) - sgn;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
static inline __device__ int4_t abs(int4_t x)
|
||||
{
|
||||
int4_t sgn = x >> (4 - 1);
|
||||
|
||||
return (x ^ sgn) - sgn;
|
||||
};
|
||||
#endif
|
||||
|
||||
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
|
||||
|
||||
static inline __device__ bool isnan(float x) { return ::isnan(x); };
|
||||
@@ -107,6 +132,14 @@ static inline __device__ bool isnan(int32_t x)
|
||||
return false;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
static inline __device__ bool isnan(int4_t x)
|
||||
{
|
||||
(void)x;
|
||||
return false;
|
||||
};
|
||||
#endif
|
||||
|
||||
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
|
||||
|
||||
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
|
||||
|
||||
@@ -51,3 +51,4 @@ add_subdirectory(grouped_convnd_fwd)
|
||||
add_subdirectory(block_to_ctile_map)
|
||||
add_subdirectory(softmax)
|
||||
add_subdirectory(layernorm)
|
||||
add_subdirectory(data_type)
|
||||
|
||||
4
test/data_type/CMakeLists.txt
Normal file
4
test/data_type/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
if (USE_BITINT_EXTENSION_INT4)
|
||||
add_gtest_executable(test_int4 int4.cpp)
|
||||
target_link_libraries(test_int4 PRIVATE utility)
|
||||
endif()
|
||||
44
test/data_type/int4.cpp
Normal file
44
test/data_type/int4.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
|
||||
using ck::int4_t;
|
||||
|
||||
TEST(Int4, BaseArithmetic)
|
||||
{
|
||||
int4_t a{1};
|
||||
int4_t b{-2};
|
||||
EXPECT_EQ(a + a, int4_t{2});
|
||||
EXPECT_EQ(a - a, int4_t{0});
|
||||
EXPECT_EQ(a + b, int4_t{-1});
|
||||
EXPECT_EQ(a - b, int4_t{3});
|
||||
EXPECT_EQ(a * a, int4_t{1});
|
||||
EXPECT_EQ(a * b, int4_t{-2});
|
||||
EXPECT_EQ(b * b, int4_t{4});
|
||||
EXPECT_EQ(a / b, int4_t{0});
|
||||
a = int4_t{4};
|
||||
EXPECT_EQ(a / b, int4_t{-2});
|
||||
b = int4_t{2};
|
||||
EXPECT_EQ(a % b, int4_t{0});
|
||||
}
|
||||
|
||||
TEST(Int4, NumericLimits)
|
||||
{
|
||||
EXPECT_EQ(ck::NumericLimits<int4_t>::Min(), int4_t{-7});
|
||||
EXPECT_EQ(ck::NumericLimits<int4_t>::Max(), int4_t{7});
|
||||
EXPECT_EQ(ck::NumericLimits<int4_t>::Lowest(), int4_t{-7});
|
||||
}
|
||||
|
||||
TEST(Int4, MathOpsV2)
|
||||
{
|
||||
int4_t a{4};
|
||||
int4_t b{-5};
|
||||
|
||||
EXPECT_EQ(ck::math::abs(a), int4_t{4});
|
||||
EXPECT_EQ(ck::math::abs(b), int4_t{5});
|
||||
EXPECT_FALSE(ck::math::isnan(b));
|
||||
}
|
||||
Reference in New Issue
Block a user