[CK TILE] GEMM with packed i4 (#1885)

* [CK TILE] GEMM with packed i4

* Fixes

* fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-02-20 09:59:49 +01:00
committed by GitHub
parent 824e2c1737
commit 4d9973ec8e
32 changed files with 882 additions and 305 deletions

View File

@@ -376,14 +376,12 @@ struct numeric<bfloat16_t>
}
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<bfloat16_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int PackedSize = 1;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE

View File

@@ -207,9 +207,6 @@ using bf8_t = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<fp8_t>
{
@@ -225,6 +222,7 @@ struct numeric_traits<fp8_t>
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
static constexpr int PackedSize = 1;
};
template <>
@@ -242,6 +240,7 @@ struct numeric_traits<bf8_t>
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
static constexpr int PackedSize = 1;
};
// below is sw fp8 conversion, not utilizing hw instruction

View File

@@ -223,9 +223,6 @@ struct numeric<half_t>
}
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<half_t>
{
@@ -241,6 +238,7 @@ struct numeric_traits<half_t>
static constexpr uint16_t NegInf = 0xFC00;
static constexpr uint16_t NaN = 0x7C01;
static constexpr uint16_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
@@ -383,4 +381,24 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
@@ -74,8 +74,6 @@ struct numeric<int8_t>
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
@@ -91,6 +89,7 @@ struct numeric_traits<int8_t>
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
#endif

View File

@@ -77,7 +77,10 @@ struct numeric
};
template <typename T>
struct numeric_traits;
struct numeric_traits
{
static constexpr int PackedSize = 1;
};
template <>
struct numeric_traits<float>
@@ -94,6 +97,7 @@ struct numeric_traits<float>
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr int PackedSize = 1;
using bitwise_type = uint32_t;
};

View File

@@ -21,8 +21,8 @@ struct pk_int4_t
{
using type = int8_t;
type data;
__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
};
// limits
@@ -91,6 +91,16 @@ struct numeric<pk_int4_t>
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
template <>
struct numeric_traits<pk_int4_t>
{
static constexpr int PackedSize = 2;
};
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

View File

@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -30,17 +31,34 @@ struct native_t
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
namespace impl {
template <typename T_, index_t N_, typename = void>
struct ext_vector;
template <typename T_, index_t N_>
struct ext_vector
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
using value_type = typename native_t<remove_cvref_t<T_>>::type;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
@@ -48,6 +66,17 @@ struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
} // namespace impl
template <typename T, index_t N>
@@ -55,10 +84,11 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T>
template <typename T, typename>
struct vector_traits
{
using scalar_type = remove_cvref_t<T>;
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
static constexpr index_t vector_size = 1;
};
@@ -66,7 +96,7 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
using scalar_type = T;
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;
};
@@ -200,21 +230,11 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
// pk_int4_t
// using pk_int4_t
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
} // namespace ck_tile