mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Ck int4 moe develop (#1949)
* Add Gemm fp8xint4 example and kernel, function pass. * Init Gemm_fp8xint4 Bpreshuffle * Added gemm_fp8xint4_Bpreshuffle files, function not checked yet * General fix. * fp8xint4 bpreshuffle function pass * fix. * init b preshuffle dequant in VGPR. * fix bug, function pass. * move b thread dequant copy to blockwise. * fix bug, function now passes. * modified the tile size to 256, 128x128x128. * fixed a bug. * Initial int4 moe, compile pass, function not check. * fix bug in moe_gemm1.cpp, now function pass. * test expert = 8 and function pass. * Added moe_pk_i4_gemm2, function pass. * Added b preshuffle pipeline v3 support. * fixed merge issue. fp8xint4 and fp8xint4_bpreshuffle function pass. * Split the blockwise pipeline for fp8xint4. * commit missing files * opt gemm2 to 2x2 wave * fix swizzle = false * update int4 moe with latest input changes. * update tile size. * enable pipeline v3. * fix nswizzle = true * commit a version for compiler debug. * Updated transfer_v3r1_gather to support pk_i4_t type. * for int4 moe2 for type_convert support. * remove some values between mfma instructions. * fix int4 moe * Updated transfer_v3r1_gather to support pk_i4_t type. * i4 support lds multiple shuffle * fixed int4 moe tflops calculation. * Modified CshuffleCShuffleMXdlPerWavePerShuffle to 1 to suit C multiple shuffle * updated gemm2. * change int4 moe example names * fix and format code. * format. * format codes. * update fp8xint4 example tile size. * add <unordered_map> header * fixed. * format. * Added conditional compilation for int4 -> fp8 conversion kernels --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -79,6 +79,24 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// Builds a packed f8x4_t from four 4-bit fields of `q`.
//
// `q` is passed through amd_assembly_and_b32 twice: once with a mask that
// keeps bits [3:0] and [19:16], and once with a mask that keeps bits [7:4]
// and [23:20]. Each masked word is handed to amd_assemble_cvt_f32_i4 both
// as-is and shifted right by 16; the four resulting floats are then packed
// by amd_assembly_cvt_f8_to_f32 in the order
// (even/low half, even/high half, odd/low half, odd/high half).
//
// NOTE(review): the odd-nibble word is never shifted right by 4, so its
// selected fields still sit at bits [7:4] when they reach
// amd_assemble_cvt_f32_i4. Whether that is intended depends on which source
// bits that helper reads -- confirm against its definition.
// NOTE(review): bits [15:8] and [31:24] of `q` are not consumed here.
__device__ inline f8x4_t i4_to_f8x4(int q)
{
    const int even_nibble_mask = 0x000f000f;
    const int odd_nibble_mask  = 0x00f000f0;

    int even_nibbles = amd_assembly_and_b32(q, even_nibble_mask);
    int odd_nibbles  = amd_assembly_and_b32(q, odd_nibble_mask);

    // The conversions are issued in the same sequence as the packing order.
    float even_low_half  = amd_assemble_cvt_f32_i4(even_nibbles);
    float even_high_half = amd_assemble_cvt_f32_i4(even_nibbles >> 16);
    float odd_low_half   = amd_assemble_cvt_f32_i4(odd_nibbles);
    float odd_high_half  = amd_assemble_cvt_f32_i4(odd_nibbles >> 16);

    return amd_assembly_cvt_f8_to_f32(even_low_half, even_high_half, odd_low_half, odd_high_half);
}
|
||||
|
||||
// Forwards the 32-bit word `q` to amd_assembly_i4_to_fp8x8 and returns the
// resulting f8x8_t unchanged. No processing is done in this wrapper.
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
    return amd_assembly_i4_to_fp8x8(q);
}
|
||||
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
@@ -142,6 +160,55 @@ struct PassThroughPack8
|
||||
#endif
|
||||
}
|
||||
|
||||
// Converts a pack of four pk_i4_t values (`x`) into eight f8_t values (`y`).
//
// With CK_USE_PK4_LAYOUT_SHUFFLE enabled, `x` is reinterpreted as an int via
// bit_cast and converted in one step by i4_to_fp8x8.
// Otherwise the conversion goes through float: each pk_i4_t is widened to a
// float2_t with type_convert, then each of the eight floats is narrowed to
// f8_t with type_convert.
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
    y = i4_to_fp8x8(bit_cast<int>(x));
#else
    vector_type<f8_t, 8> f8_lanes;
    vector_type<float, 8> f32_lanes;
    vector_type<pk_i4_t, 4> packed_i4{x};

    // Stage 1: pk_i4_t -> float2_t. Pair k fills float lanes 2k and 2k+1.
    f32_lanes.template AsType<float2_t>()(Number<0>{}) =
        type_convert<float2_t>(packed_i4.template AsType<pk_i4_t>()[Number<0>{}]);
    f32_lanes.template AsType<float2_t>()(Number<1>{}) =
        type_convert<float2_t>(packed_i4.template AsType<pk_i4_t>()[Number<1>{}]);
    f32_lanes.template AsType<float2_t>()(Number<2>{}) =
        type_convert<float2_t>(packed_i4.template AsType<pk_i4_t>()[Number<2>{}]);
    f32_lanes.template AsType<float2_t>()(Number<3>{}) =
        type_convert<float2_t>(packed_i4.template AsType<pk_i4_t>()[Number<3>{}]);

    // Stage 2: float -> f8_t, one lane at a time, lane order preserved.
    f8_lanes.template AsType<f8_t>()(Number<0>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<0>{}]);
    f8_lanes.template AsType<f8_t>()(Number<1>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<1>{}]);
    f8_lanes.template AsType<f8_t>()(Number<2>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<2>{}]);
    f8_lanes.template AsType<f8_t>()(Number<3>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<3>{}]);
    f8_lanes.template AsType<f8_t>()(Number<4>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<4>{}]);
    f8_lanes.template AsType<f8_t>()(Number<5>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<5>{}]);
    f8_lanes.template AsType<f8_t>()(Number<6>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<6>{}]);
    f8_lanes.template AsType<f8_t>()(Number<7>{}) =
        type_convert<f8_t>(f32_lanes.template AsType<float>()[Number<7>{}]);

    // Hand the eight converted lanes back as a single f8x8_t.
    y = f8_lanes.template AsType<f8x8_t>()[Number<0>{}];
#endif
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
|
||||
Reference in New Issue
Block a user