mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Fix pk_int4 cast and add pk_int4 dtype in ck tile (#1854)
* Fix pk_int4 cast and add pk_int4 dtype in ck tile * fixes * Improvements * fix typo
This commit is contained in:
@@ -16,7 +16,8 @@ namespace ck {
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__host__ __device__ inline half4_t pki4_to_half4(int q)
|
||||
// Convert lower part of packed int4 -> int4 to half
|
||||
__device__ inline half4_t i4_to_half4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
|
||||
{
|
||||
#if 1
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
|
||||
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
|
||||
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0xE408E408; //-8
|
||||
|
||||
int lo = i4s | EX;
|
||||
|
||||
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
|
||||
#else
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
|
||||
|
||||
vector_type<half_t, 2> res;
|
||||
|
||||
half_t x_h = (x_u8 & 0x0f) - 8;
|
||||
half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
|
||||
|
||||
res.template AsType<half_t>()(Number<0>{}) = x_l;
|
||||
res.template AsType<half_t>()(Number<1>{}) = x_h;
|
||||
|
||||
return res.template AsType<half2_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
|
||||
@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
|
||||
return res.template AsType<bhalf4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
|
||||
|
||||
float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
|
||||
float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
|
||||
|
||||
vector_type<bhalf_t, 2> res;
|
||||
|
||||
res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
|
||||
res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
|
||||
|
||||
return res.template AsType<bhalf2_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
@@ -159,11 +118,11 @@ struct PassThroughPack8
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<half_t, 8> result;
|
||||
|
||||
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
|
||||
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
|
||||
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
|
||||
result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);
|
||||
|
||||
y = result.template AsType<half8_t>()[Number<0>{}];
|
||||
#else
|
||||
@@ -171,13 +130,13 @@ struct PassThroughPack8
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<half2_t>()(Number<0>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<half2_t>()(Number<1>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<half2_t>()(Number<2>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<half2_t>()(Number<3>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<half8_t>()[Number<0>{}];
|
||||
#endif
|
||||
@@ -185,11 +144,11 @@ struct PassThroughPack8
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
@@ -197,13 +156,13 @@ struct PassThroughPack8
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<bhalf2_t>()(Number<0>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<1>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<2>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<3>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
@@ -219,12 +178,12 @@ struct DequantPack8
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
vector_type<half_t, 8> result;
|
||||
|
||||
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
|
||||
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
|
||||
result.template AsType<half4_t>()(Number<1>{}) =
|
||||
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
|
||||
y = result.template AsType<half8_t>()[Number<0>{}];
|
||||
#else
|
||||
@@ -232,13 +191,13 @@ struct DequantPack8
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<half2_t>()(Number<0>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<half2_t>()(Number<1>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<half2_t>()(Number<2>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<half2_t>()(Number<3>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<half8_t>()[Number<0>{}];
|
||||
#endif
|
||||
@@ -260,7 +219,7 @@ struct PassThroughPack2
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
Reference in New Issue
Block a user