mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Fix pk_int4 cast and add pk_int4 dtype in ck tile (#1854)
* Fix pk_int4 cast and add pk_int4 dtype in ck tile * fixes * Improvements * fix typo
This commit is contained in:
@@ -163,6 +163,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
// set rounding to nearest even as default for f8 conversions
|
||||
#define CK_USE_SR_F8_CONVERSION 0
|
||||
|
||||
// shuffle pk_i4 values during conversion to optimize number of binary
|
||||
// operations
|
||||
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ namespace ck {
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__host__ __device__ inline half4_t pki4_to_half4(int q)
|
||||
// Convert lower part of packed int4 -> int4 to half
|
||||
__device__ inline half4_t i4_to_half4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
|
||||
{
|
||||
#if 1
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
|
||||
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
|
||||
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0xE408E408; //-8
|
||||
|
||||
int lo = i4s | EX;
|
||||
|
||||
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
|
||||
#else
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
|
||||
|
||||
vector_type<half_t, 2> res;
|
||||
|
||||
half_t x_h = (x_u8 & 0x0f) - 8;
|
||||
half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
|
||||
|
||||
res.template AsType<half_t>()(Number<0>{}) = x_l;
|
||||
res.template AsType<half_t>()(Number<1>{}) = x_h;
|
||||
|
||||
return res.template AsType<half2_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
|
||||
@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
|
||||
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
|
||||
|
||||
float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
|
||||
float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
|
||||
|
||||
vector_type<bhalf_t, 2> res;
|
||||
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
|
||||
res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
|
||||
|
||||
return res.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
@@ -159,11 +118,11 @@ struct PassThroughPack8
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<half_t, 8> result;
|
||||
|
||||
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
|
||||
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
|
||||
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
|
||||
result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);
|
||||
|
||||
y = result.template AsType<half8_t>()[Number<0>{}];
|
||||
#else
|
||||
@@ -171,13 +130,13 @@ struct PassThroughPack8
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<half2_t>()(Number<0>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<half2_t>()(Number<1>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<half2_t>()(Number<2>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<half2_t>()(Number<3>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<half8_t>()[Number<0>{}];
|
||||
#endif
|
||||
@@ -185,11 +144,11 @@ struct PassThroughPack8
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
@@ -197,13 +156,13 @@ struct PassThroughPack8
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<bhalf2_t>()(Number<0>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<1>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<2>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<3>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
@@ -219,12 +178,12 @@ struct DequantPack8
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<half_t, 8> result;
|
||||
|
||||
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
|
||||
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
|
||||
result.template AsType<half4_t>()(Number<1>{}) =
|
||||
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
|
||||
y = result.template AsType<half8_t>()[Number<0>{}];
|
||||
#else
|
||||
@@ -232,13 +191,13 @@ struct DequantPack8
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<half2_t>()(Number<0>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<half2_t>()(Number<1>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<half2_t>()(Number<2>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<half2_t>()(Number<3>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<half8_t>()[Number<0>{}];
|
||||
#endif
|
||||
@@ -260,7 +219,7 @@ struct PassThroughPack2
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "ck/utility/f8_utils.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/array.hpp"
|
||||
#include "ck/utility/amd_inline_asm.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for MI300 models
|
||||
@@ -14,6 +16,26 @@ namespace ck {
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
namespace details {
|
||||
|
||||
[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
|
||||
{
|
||||
half2_t vector_res;
|
||||
|
||||
vector_res.x = x.x + y.x;
|
||||
vector_res.y = x.y + y.y;
|
||||
|
||||
return vector_res;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
|
||||
{
|
||||
return amd_assembly_pk_add_f16(x, y);
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
@@ -520,13 +542,51 @@ template <>
|
||||
// Converts one packed pair of 4-bit values into two floats.
//
// The byte is split into its low and high nibble and 8 is subtracted from each,
// so a nibble in [0, 15] becomes a float in [-8, 7].
//
// The element order of the result follows the *value* of
// CK_USE_PK4_LAYOUT_SHUFFLE (tested with #if, not #ifdef, so that defining the
// macro as 0 really selects the non-shuffled order):
//   non-zero : {high nibble, low nibble}
//   zero     : {low nibble, high nibble}
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#if CK_USE_PK4_LAYOUT_SHUFFLE
    float2_t res = {x_h, x_l};
#else
    float2_t res = {x_l, x_h};
#endif
    return res;
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
|
||||
#else
|
||||
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
|
||||
#endif
|
||||
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0xE408E408; //-8
|
||||
|
||||
int lo = i4s | EX;
|
||||
|
||||
return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
|
||||
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
|
||||
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
|
||||
|
||||
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
|
||||
#else
|
||||
bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
|
||||
#endif
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
|
||||
@@ -144,6 +144,10 @@
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
#endif
|
||||
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
|
||||
140
include/ck_tile/core/numeric/pk_int4.hpp
Normal file
140
include/ck_tile/core/numeric/pk_int4.hpp
Normal file
@@ -0,0 +1,140 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Packed 2xint4
|
||||
struct pk_int4_t
|
||||
{
|
||||
using type = int8_t;
|
||||
type data;
|
||||
__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
|
||||
__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_int4_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
|
||||
{
|
||||
constexpr uint8_t val = 0b10001000;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
|
||||
{
|
||||
constexpr uint8_t val = 0b10001000;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
|
||||
{
|
||||
constexpr uint8_t val = 0b01110111;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
|
||||
};
|
||||
|
||||
// Converts one packed pair of 4-bit values into two floats.
//
// The byte is split into its low and high nibble and 8 is subtracted from each,
// so a nibble in [0, 15] becomes a float in [-8, 7].
//
// Element order follows the *value* of CK_TILE_USE_PK4_LAYOUT_SHUFFLE (tested
// with #if, so building with -DCK_TILE_USE_PK4_LAYOUT_SHUFFLE=0 selects the
// non-shuffled order):
//   non-zero : {high nibble, low nibble}
//   zero     : {low nibble, high nibble}
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}
|
||||
|
||||
// Converts one packed pair of 4-bit values into two half-precision floats
// without an integer-to-float conversion instruction.
//
// Each nibble is placed in the low bits of a 16-bit lane and OR-ed with 0x6400
// (the fp16 encoding of 1024.0), which yields the half value 1024 + nibble.
// Adding 0xE408 (the fp16 encoding of -1032.0) to both lanes then leaves
// nibble - 8 in each lane.
//
// Lane order follows the *value* of CK_TILE_USE_PK4_LAYOUT_SHUFFLE (tested with
// #if, so building with -DCK_TILE_USE_PK4_LAYOUT_SHUFFLE=0 selects the
// non-shuffled layout):
//   non-zero : lane 0 = high nibble, lane 1 = low nibble
//   zero     : lane 0 = low nibble,  lane 1 = high nibble
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400; // 1024.0 in both fp16 lanes
    const int SUB = 0xE408E408; // -(1024 + 8) in both fp16 lanes, i.e. the "-8" bias

    int lo = i4s | EX;

    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
|
||||
|
||||
// Converts one packed pair of 4-bit values into two bf16_t values.
//
// Each nibble has 8 subtracted from it as a float (giving a value in [-8, 7])
// and is then narrowed with type_convert<bf16_t>.
//
// Element order follows the *value* of CK_TILE_USE_PK4_LAYOUT_SHUFFLE (tested
// with #if, so building with -DCK_TILE_USE_PK4_LAYOUT_SHUFFLE=0 selects the
// non-shuffled order):
//   non-zero : {high nibble, low nibble}
//   zero     : {low nibble, high nibble}
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#else
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
// Host overload of the packed half add: adds the two lanes of `x` and `y`
// independently and returns the pair of sums.
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external
// definition of the function.
inline __host__ fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t vector_res;

    vector_res.x = x.x + y.x;
    vector_res.y = x.y + y.y;

    return vector_res;
}
|
||||
|
||||
// Device overload of the packed half add: issues a single v_pk_add_f16
// instruction on the two operands and returns its destination register.
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external
// definition of the function.
inline __device__ fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
    return c;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -2,3 +2,4 @@ add_subdirectory(image_to_column)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(data_type)
|
||||
|
||||
4
test/ck_tile/data_type/CMakeLists.txt
Normal file
4
test/ck_tile/data_type/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
|
||||
endif()
|
||||
65
test/ck_tile/data_type/test_pk_int4.cpp
Normal file
65
test/ck_tile/data_type/test_pk_int4.cpp
Normal file
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
using ck_tile::bf16_t;
|
||||
using ck_tile::bf16x2_t;
|
||||
using ck_tile::fp16x2_t;
|
||||
using ck_tile::fp32x2_t;
|
||||
using ck_tile::half_t;
|
||||
using ck_tile::pk_int4_t;
|
||||
|
||||
// Checks pk_int4_t -> fp32x2_t conversion of the byte 0b11110111.
// The expected element order depends on the value of
// CK_TILE_USE_PK4_LAYOUT_SHUFFLE, so it is tested with #if: with #ifdef the
// non-shuffled expectations could never be selected once the macro is defined.
TEST(PackedInt4, ConvertToFloat)
{
#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    constexpr float first_input_val  = 7.f;
    constexpr float second_input_val = -1.f;
#else
    constexpr float first_input_val  = -1.f;
    constexpr float second_input_val = 7.f;
#endif
    uint8_t data = 0b11110111; // low nibble -> -1, high nibble -> 7 after the -8 bias
    pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
    fp32x2_t out = ck_tile::pk_int4_t_to_fp32x2_t(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
|
||||
|
||||
// Checks pk_int4_t -> fp16x2_t conversion of the byte 0b11110111.
// The expected element order depends on the value of
// CK_TILE_USE_PK4_LAYOUT_SHUFFLE, so it is tested with #if: with #ifdef the
// non-shuffled expectations could never be selected once the macro is defined.
TEST(PackedInt4, ConvertToHalf)
{
#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    const half_t first_input_val  = ck_tile::type_convert<half_t>(7.f);
    const half_t second_input_val = ck_tile::type_convert<half_t>(-1.f);
#else
    const half_t first_input_val  = ck_tile::type_convert<half_t>(-1.f);
    const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
#endif
    uint8_t data = 0b11110111; // low nibble -> -1, high nibble -> 7 after the -8 bias
    pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
    fp16x2_t out = ck_tile::pk_int4_t_to_halfx2_t(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
|
||||
|
||||
// Checks pk_int4_t -> bf16x2_t conversion of the byte 0b11110111.
// The expected element order depends on the value of
// CK_TILE_USE_PK4_LAYOUT_SHUFFLE, so it is tested with #if: with #ifdef the
// non-shuffled expectations could never be selected once the macro is defined.
TEST(PackedInt4, ConvertToBHalf)
{
#if CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    const bf16_t first_input_val  = ck_tile::type_convert<bf16_t>(7.f);
    const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(-1.f);
#else
    const bf16_t first_input_val  = ck_tile::type_convert<bf16_t>(-1.f);
    const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
#endif
    uint8_t data = 0b11110111; // low nibble -> -1, high nibble -> 7 after the -8 bias
    pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
    bf16x2_t out = ck_tile::pk_int4_t_to_bfloat16x2_t(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
|
||||
@@ -50,3 +50,4 @@ endif()
|
||||
|
||||
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
|
||||
add_gtest_executable(test_bhalf test_bhalf.cpp)
|
||||
add_gtest_executable(test_pk_i4 test_pk_i4.cpp)
|
||||
|
||||
77
test/data_type/test_pk_i4.cpp
Normal file
77
test/data_type/test_pk_i4.cpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <bitset>
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <iomanip>
|
||||
#include "gtest/gtest.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
using ck::bhalf2_t;
|
||||
using ck::bhalf_t;
|
||||
using ck::float2_t;
|
||||
using ck::half2_t;
|
||||
using ck::half4_t;
|
||||
using ck::half_t;
|
||||
using ck::pk_i4_t;
|
||||
using ck::pk_i4x4_t;
|
||||
|
||||
// Checks pk_i4_t -> float2_t conversion of the byte 0b11110111.
// The expected element order depends on the value of CK_USE_PK4_LAYOUT_SHUFFLE
// (defined as 0 or 1), so it is tested with #if: with #ifdef the non-shuffled
// expectations could never be selected.
TEST(PackedInt4, ConvertToFloat)
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
    constexpr float first_input_val  = 7.f;
    constexpr float second_input_val = -1.f;
#else
    constexpr float first_input_val  = -1.f;
    constexpr float second_input_val = 7.f;
#endif
    uint8_t data = 0b11110111; // low nibble -> -1, high nibble -> 7 after the -8 bias
    pk_i4_t in   = ck::bit_cast<int8_t>(data);
    float2_t out = ck::type_convert<float2_t>(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
|
||||
|
||||
// Checks pk_i4_t -> half2_t conversion of the byte 0b11110111.
// The expected element order depends on the value of CK_USE_PK4_LAYOUT_SHUFFLE
// (defined as 0 or 1), so it is tested with #if: with #ifdef the non-shuffled
// expectations could never be selected.
TEST(PackedInt4, ConvertToHalf)
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
    constexpr half_t first_input_val  = ck::type_convert<half_t>(7.f);
    constexpr half_t second_input_val = ck::type_convert<half_t>(-1.f);
#else
    constexpr half_t first_input_val  = ck::type_convert<half_t>(-1.f);
    constexpr half_t second_input_val = ck::type_convert<half_t>(7.f);
#endif
    uint8_t data = 0b11110111; // low nibble -> -1, high nibble -> 7 after the -8 bias
    pk_i4_t in   = ck::bit_cast<int8_t>(data);
    half2_t out  = ck::type_convert<half2_t>(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
|
||||
|
||||
// Checks pk_i4_t -> bhalf2_t conversion of the byte 0b11110111.
// The expected element order depends on the value of CK_USE_PK4_LAYOUT_SHUFFLE
// (defined as 0 or 1), so it is tested with #if: with #ifdef the non-shuffled
// expectations could never be selected.
TEST(PackedInt4, ConvertToBHalf)
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
    const bhalf_t first_input_val  = ck::type_convert<bhalf_t>(7.f);
    const bhalf_t second_input_val = ck::type_convert<bhalf_t>(-1.f);
#else
    const bhalf_t first_input_val  = ck::type_convert<bhalf_t>(-1.f);
    const bhalf_t second_input_val = ck::type_convert<bhalf_t>(7.f);
#endif
    uint8_t data = 0b11110111; // low nibble -> -1, high nibble -> 7 after the -8 bias
    pk_i4_t in   = ck::bit_cast<int8_t>(data);
    bhalf2_t out = ck::type_convert<bhalf2_t>(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
|
||||
Reference in New Issue
Block a user