Ck int4 moe develop (#1949)

* Add Gemm fp8xint4 example and kernel, function pass.

* Init Gemm_fp8xint4 Bpreshuffle

* Added gemm_fp8xint4_Bpreshuffle files, function not checked yet

* General fix.

* fp8xint4 bpreshuffle function pass

* fix.

* init b preshuffle dequant in VGPR.

* fix bug, function pass.

* move b thread dequant copy to blockwise.

* fix bug, function now passes.

* modified the tile size to 256, 128x128x128.

* fixed a bug.

* Initial int4 moe, compile pass, function not check.

* fix bug in moe_gemm1.cpp, now function pass.

* test expert = 8 and function pass.

* Added moe_pk_i4_gemm2, function pass.

* Added b preshuffle pipeline v3 support.

* fixed merge issue. fp8xint4 and fp8xint4_bpreshuffle function pass.

* Split the blockwise pipeline for fp8xint4.

* commit missing files

* opt gemm2 to 2x2 wave

* fix swizzle = false

* update int4 moe with latest input changes.

* update tile size.

* enable pipeline v3.

* fix nswizzle = true

* commit a version for compiler debug.

* Updated transfer_v3r1_gather to support pk_i4_t type.

* for int4 moe2 for type_convert support.

* remove some values between mfma instructions.

* fix int4 moe

* Updated transfer_v3r1_gather to support pk_i4_t type.

* i4 support lds multiple shuffle

* fixed int4 moe tflops calculation.

* Modified CshuffleCShuffleMXdlPerWavePerShuffle to 1 to suit C multiple shuffle

* updated gemm2.

* change int4 moe example names

* fix and format code.

* format.

* format codes.

* update fp8xint4 example tile size.

* add <unordered_map> header

* fixed.

* format.

* Added conditional compilation for int4 -> fp8 conversion kernels

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
Mingtao Gu
2025-03-10 11:16:44 +08:00
committed by GitHub
parent c954bd0cfa
commit 0db7c8f0b2
19 changed files with 6018 additions and 83 deletions

View File

@@ -79,6 +79,24 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
return res.template AsType<half4_t>()[Number<0>{}];
}
__device__ inline f8x4_t i4_to_f8x4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
int lo = amd_assembly_and_b32(q, LO);
int hi = amd_assembly_and_b32(q, HI);
float f32_0 = amd_assemble_cvt_f32_i4(lo);
float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16);
float f32_2 = amd_assemble_cvt_f32_i4(hi);
float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16);
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
}
__device__ inline f8x8_t i4_to_fp8x8(int q) { return amd_assembly_i4_to_fp8x8(q); }
__device__ inline bhalf4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
@@ -142,6 +160,55 @@ struct PassThroughPack8
#endif
}
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
y = i4_to_fp8x8(bit_cast<int>(x));
#else
// Added pk_i4_t to f8x2_fnuz_t conversion
vector_type<f8_t, 8> dst;
vector_type<float, 8> dst_tmp;
vector_type<pk_i4_t, 4> src{x};
// pk_i4_t to float2_t conversion
dst_tmp.template AsType<float2_t>()(Number<0>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst_tmp.template AsType<float2_t>()(Number<1>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst_tmp.template AsType<float2_t>()(Number<2>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst_tmp.template AsType<float2_t>()(Number<3>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
// float to f8_t conversion
dst.template AsType<f8_t>()(Number<0>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<0>{}]);
dst.template AsType<f8_t>()(Number<1>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<1>{}]);
dst.template AsType<f8_t>()(Number<2>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<2>{}]);
dst.template AsType<f8_t>()(Number<3>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<3>{}]);
dst.template AsType<f8_t>()(Number<4>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<4>{}]);
dst.template AsType<f8_t>()(Number<5>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<5>{}]);
dst.template AsType<f8_t>()(Number<6>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<6>{}]);
dst.template AsType<f8_t>()(Number<7>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<7>{}]);
y = dst.template AsType<f8x8_t>()[Number<0>{}];
#endif
}
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE