mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Jing's contribution: prototype of mixed precision gemm FP16/BF16xint4 GEMM (#1762)
* add a prototype of int4 * clean * debug * clean * clean * move packed into dynamic_buffer * fixed coord reset * add fast pki4 to half conversion * fix * fixed reference and host_tensor * fixed tensor init * format * debug i4_to_f16_convert * format * fixed splitk * weight permute * add b tile permute * clean * weight permute with splitki * format * improve weight layout * add and_or_b32 * fixed splitk crush * add permute switch as a template * recover v3r1 * clean * failure with intrawave v2 * fixed * fixed * add ckProfiler * add bfp16 support * add bf16 example * fixed int4 to bhalf_t conversion * format * fixed int4 to bf16 conversion * clean * add instances for mem * clean * fixed host tensor size * fixed * debug * fixed * add pk_i4_t as a struct * fix * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * revert * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed comments * revert * clean * revert * revert * fixed * Update CMakeLists.txt * Update script/cmake-ck-dev.sh Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update 
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update CMakeLists.txt Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed * fixed * fixed * revert * revert * add comments * format * fixed assert * fixed * Fix I4 define in ckProfiler * Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -7,12 +7,177 @@
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/amd_inline_asm.hpp"
|
||||
#include <cassert>
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Fast int4x4 to half8_t data type conversion based on the paper
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
// Converts four 4-bit fields of `q` into a half4_t, mapping each unsigned field value v
// in [0, 15] to v - 8.
//
// The fields selected by mask 0x000f000f feed the first half2_t of the result and the
// fields selected by mask 0x00f000f0 feed the second half2_t. Bits of `q` outside these
// two masks are ignored.
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
    // 0x6400 is the fp16 bit pattern of 1024; at that exponent one mantissa ulp equals 1.
    const int fp16_bias    = 0x64006400;
    const int low_nibbles  = 0x000f000f;
    const int high_nibbles = 0x00f000f0;

    // Mask out one nibble per 16-bit lane and merge it with the fp16 bias pattern
    // (presumably (q & mask) | bias -- see amd_assembly_and_or_b32).
    // Low nibbles land in mantissa bits 0-3, so each lane reads as 1024 + v.
    const int biased_low = amd_assembly_and_or_b32(q, low_nibbles, fp16_bias);
    // High nibbles land in mantissa bits 4-7, so each lane reads as 1024 + 16 * v.
    const int biased_high = amd_assembly_and_or_b32(q, high_nibbles, fp16_bias);

    const int minus_1032 = 0xE408E408; // half2 {-1032, -1032}
    const int one_16th   = 0x2c002c00; // half2 {1 / 16, 1 / 16}
    const int minus_72   = 0xd480d480; // half2 {-72, -72}

    vector_type<half_t, 4> out;

    // (1024 + v) - 1032 == v - 8
    out.template AsType<half2_t>()(Number<0>{}) =
        amd_assembly_pk_add_f16(bit_cast<half2_t>(biased_low), bit_cast<half2_t>(minus_1032));

    // (1024 + 16 * v) / 16 - 72 == v - 8
    out.template AsType<half2_t>()(Number<1>{}) =
        amd_assembly_pk_fma_f16(bit_cast<half2_t>(biased_high),
                                bit_cast<half2_t>(one_16th),
                                bit_cast<half2_t>(minus_72));

    return out.template AsType<half4_t>()[Number<0>{}];
}
|
||||
|
||||
// Converts the two 4-bit fields packed in `q` into a half2_t, mapping each unsigned
// field value v in [0, 15] to v - 8.
//
// Scalar equivalent: element 0 = ((byte & 0xf0) >> 4) - 8, element 1 = (byte & 0x0f) - 8,
// i.e. the high nibble ends up in the low 16 bits of the result's bit pattern.
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
    const uint8_t x_u8 = ck::bit_cast<uint8_t>(q);

    // Put each nibble into the lowest mantissa bits of its own 16-bit lane:
    // low nibble -> upper lane, high nibble -> lower lane.
    const uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);

    const int EX  = 0x64006400; // half2 {1024, 1024}
    const int SUB = 0xE408E408; // half2 {-1032, -1032}: removes the 1024 bias and offsets by -8

    // OR-ing the nibbles into the mantissa of 1024 yields the fp16 value 1024 + v per lane.
    const int lo = i4s | EX;

    // (1024 + v) - 1032 == v - 8
    return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
|
||||
|
||||
// Converts the four 4-bit fields in the low 16 bits of `q` into a bhalf4_t, mapping each
// unsigned field value v in [0, 15] to v - 8. Bits 16-31 of `q` are ignored.
//
// Each result bhalf2_t is assembled with the odd-numbered nibble in its low 16 bits and
// the even-numbered nibble in its high 16 bits (see the 0x7632 selectors below).
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
{
    // Spread the nibbles so that nibble k sits in the low 4 bits of byte k.
    const uint32_t i8s =
        (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

    // Bit pattern of 8388608.0f (2^23); at this exponent one mantissa ulp equals 1, so
    // replacing the lowest byte with v produces the float 8388608 + v.
    static constexpr uint32_t fp32_base = 0x4B000000;

    // Selectors 0x7650..0x7653 keep bytes 1-3 of fp32_base and take byte k of i8s as byte 0.
    const uint32_t biased0 = __byte_perm(i8s, fp32_base, 0x7650);
    const uint32_t biased1 = __byte_perm(i8s, fp32_base, 0x7651);
    const uint32_t biased2 = __byte_perm(i8s, fp32_base, 0x7652);
    const uint32_t biased3 = __byte_perm(i8s, fp32_base, 0x7653);

    // (8388608 + v) - 8388616 == v - 8.
    // The bits are reinterpreted by value with bit_cast rather than by aliasing a float
    // array through a uint32_t pointer, which would violate strict aliasing.
    const float f0 = bit_cast<float>(biased0) - 8388616.f;
    const float f1 = bit_cast<float>(biased1) - 8388616.f;
    const float f2 = bit_cast<float>(biased2) - 8388616.f;
    const float f3 = bit_cast<float>(biased3) - 8388616.f;

    // Selector 0x7632 packs the upper 16 bits of the first operand into the low half and
    // the upper 16 bits of the second operand into the high half; the upper 16 bits of
    // an fp32 are used as the bf16 bit pattern.
    vector_type<bhalf_t, 4> res;
    res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
        __byte_perm(bit_cast<uint32_t>(f1), bit_cast<uint32_t>(f0), 0x7632));
    res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
        __byte_perm(bit_cast<uint32_t>(f3), bit_cast<uint32_t>(f2), 0x7632));

    return res.template AsType<bhalf4_t>()[Number<0>{}];
}
|
||||
|
||||
// Converts the two 4-bit fields packed in `q` into a bhalf2_t, mapping each unsigned
// field value v in [0, 15] to v - 8. Element 0 comes from the high nibble of the byte,
// element 1 from the low nibble.
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
    const uint8_t packed = ck::bit_cast<uint8_t>(q);

    // Apply the -8 offset in fp32 before narrowing to bhalf_t.
    const float from_low_nibble  = (packed & 0x0f) - 8.f;
    const float from_high_nibble = ((packed & 0xf0) >> 4) - 8.f;

    vector_type<bhalf_t, 2> out;

    out.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(from_high_nibble);
    out.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(from_low_nibble);

    return out.template AsType<bhalf2_t>()[Number<0>{}];
}
|
||||
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
vector_type<half_t, 8> result;
|
||||
|
||||
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
|
||||
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
|
||||
|
||||
y = result.template AsType<half8_t>()[Number<0>{}];
|
||||
#else
|
||||
vector_type<half_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<half2_t>()(Number<0>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<half2_t>()(Number<1>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<half2_t>()(Number<2>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<half2_t>()(Number<3>{}) =
|
||||
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<half8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if 1
|
||||
vector_type<bhalf_t, 8> result;
|
||||
|
||||
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
|
||||
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
|
||||
y = result.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#else
|
||||
vector_type<bhalf_t, 8> dst;
|
||||
vector_type<pk_i4_t, 4> src{x};
|
||||
|
||||
dst.template AsType<bhalf2_t>()(Number<0>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<1>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<2>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
|
||||
dst.template AsType<bhalf2_t>()(Number<3>{}) =
|
||||
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
|
||||
|
||||
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
|
||||
struct UnaryOpBase
|
||||
@@ -49,6 +214,24 @@ struct PassThroughPack2
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<half2_t>(t);
|
||||
}
|
||||
|
||||
// Unpacks the two 4-bit fields of x into y as half_t values: element 0 takes the low
// nibble, element 1 the high nibble, both as plain unsigned values in [0, 15].
// NOTE(review): unlike ck::pki4_to_half2, no -8 offset is applied here and the low
// nibble comes first -- confirm this difference is intended.
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
{
#if 1
    const uint8_t packed = ck::bit_cast<uint8_t>(x);

    const uint8_t low_nibble  = packed & 0x0f;
    const uint8_t high_nibble = (packed >> 4) & 0x0f;

    y = {ck::type_convert<ck::half_t>(low_nibble), ck::type_convert<ck::half_t>(high_nibble)};
#else
    // Disabled alternative: widen the raw byte to 32 bits and reinterpret it as half2_t.
    uint32_t t = ck::bit_cast<uint8_t>(x);
    y          = ck::bit_cast<half2_t>(t);
#endif
}
|
||||
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
|
||||
@@ -76,6 +259,12 @@ struct PassThrough final : public UnaryOpBase
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user