mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement the fp16xint4 scale weight only kernel for Ali (#1786)
* enable int4 scale (weight only) kernel * format some files * Add unit test for int4 weight only * fixed and formatted code * fixed * formatted * formatted * fixed * fixed * fixed a bug in the ckProfiler, and formatted the code --------- Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -44,6 +44,40 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// Converts four packed 4-bit values taken from `q` into four fp16 values and
// multiplies them by `scale`.
//
// Only bits 0-7 and 16-23 of `q` are consumed (see the two masks below); the
// remaining bits are ignored here.
//
// Output layout (as two half2 pairs inside the returned half4):
//   pair 0: nibbles at bits [3:0] and [19:16]
//   pair 1: nibbles at bits [7:4] and [23:20]
// Each pair is multiplied lane-wise by the same `scale` pair.
//
// NOTE(review): the body uses AMD GPU inline assembly although the function is
// marked __host__ __device__ -- presumably it is only ever instantiated for the
// device; confirm against callers.
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
{
    // Nibble-selection masks: one nibble per 16-bit lane.
    const int low_nibble_mask  = 0x000f000f;
    const int high_nibble_mask = 0x00f000f0;
    // fp16 bit pattern of 1024.0 in both 16-bit lanes. OR-ing a small integer
    // into the mantissa of 1024.0 yields the fp16 value 1024 + that integer.
    const int fp16_1024_pair = 0x64006400;

    // Correction constants, expressed as fp16 bit patterns in both lanes.
    const int low_offset  = 0xE408E408; // half2 {-1032, -1032} = -(1024 + 8)
    const int high_scale  = 0x2c002c00; // half2 {1 / 16, 1 / 16}
    const int high_offset = 0xd480d480; // half2 {-72, -72}     = -(1024 / 16 + 8)

    // Select the low nibble of each lane and splice it into the fp16 pattern.
    // (amd_assembly_and_or_b32 presumably evaluates (q & mask) | fp16_1024_pair.)
    const int low_bits = amd_assembly_and_or_b32(q, low_nibble_mask, fp16_1024_pair);
    // Same for the high nibble of each lane; here the nibble sits four bits
    // higher, so the encoded value is 1024 + 16 * nibble.
    const int high_bits = amd_assembly_and_or_b32(q, high_nibble_mask, fp16_1024_pair);

    // (1024 + n) - 1032 == n - 8
    half2_t low_pair =
        amd_assembly_pk_add_f16(bit_cast<half2_t>(low_bits), bit_cast<half2_t>(low_offset));

    // (1024 + 16 * n) / 16 - 72 == n - 8
    half2_t high_pair = amd_assembly_pk_fma_f16(bit_cast<half2_t>(high_bits),
                                                bit_cast<half2_t>(high_scale),
                                                bit_cast<half2_t>(high_offset));

    // Apply the scale with a packed fp16 multiply on each pair.
    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(low_pair) : "v"(low_pair), "v"(scale));

    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(high_pair) : "v"(high_pair), "v"(scale));

    vector_type<half_t, 4> packed;
    packed.template AsType<half2_t>()(Number<0>{}) = low_pair;
    packed.template AsType<half2_t>()(Number<1>{}) = high_pair;

    return packed.template AsType<half4_t>()[Number<0>{}];
}
|
||||
|
||||
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
|
||||
{
|
||||
#if 1
|
||||
@@ -171,7 +205,42 @@ struct PassThroughPack8
|
||||
dst.template AsType<bhalf2_t>()(Number<3>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
struct DequantPack8
|
||||
{
|
||||
template <typename Y, typename X, typename Z>
|
||||
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
// Dequantizes eight packed int4 values in `x` into eight fp16 values in `y`,
// forwarding the fp16 scale pair `z` to pki4_to_half4_scale.
//
// The active path reinterprets `x` as a 32-bit integer and converts it in two
// halves: the first half4 comes from the integer as-is, the second from the
// integer shifted right by 8 bits.
__host__ __device__ constexpr void
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
{
#if 1
    const int packed_bits = bit_cast<int>(x);

    vector_type<half_t, 8> unpacked;

    unpacked.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(packed_bits, z);
    unpacked.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4_scale(packed_bits >> 8, z);

    y = unpacked.template AsType<half8_t>()[Number<0>{}];
#else
    // Disabled alternative: converts each pk_i4_t element separately via
    // pki4_to_half2.
    // NOTE(review): this branch never uses `z`, so it would produce unscaled
    // values if enabled -- confirm before switching the #if above.
    vector_type<pk_i4_t, 4> packed_src{x};
    vector_type<half_t, 8> converted;

    converted.template AsType<half2_t>()(Number<0>{}) =
        pki4_to_half2(packed_src.template AsType<pk_i4_t>()[Number<0>{}]);
    converted.template AsType<half2_t>()(Number<1>{}) =
        pki4_to_half2(packed_src.template AsType<pk_i4_t>()[Number<1>{}]);
    converted.template AsType<half2_t>()(Number<2>{}) =
        pki4_to_half2(packed_src.template AsType<pk_i4_t>()[Number<2>{}]);
    converted.template AsType<half2_t>()(Number<3>{}) =
        pki4_to_half2(packed_src.template AsType<pk_i4_t>()[Number<3>{}]);

    y = converted.template AsType<half8_t>()[Number<0>{}];
#endif
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user