CK int4 MoE develop (#1949)

* Add Gemm fp8xint4 example and kernel; function passes.

* Init Gemm_fp8xint4 Bpreshuffle

* Added gemm_fp8xint4_Bpreshuffle files; function not verified yet

* General fix.

* fp8xint4 bpreshuffle function passes

* fix.

* init b preshuffle dequant in VGPR.

* Fix bug; function passes.

* move b thread dequant copy to blockwise.

* fix bug, function now passes.

* modified the tile size to 256, 128x128x128.

* fixed a bug.

* Initial int4 moe; compiles, function not yet checked.

* Fix bug in moe_gemm1.cpp; function now passes.

* Test expert = 8; function passes.

* Added moe_pk_i4_gemm2; function passes.

* Added b preshuffle pipeline v3 support.

* Fixed merge issue; fp8xint4 and fp8xint4_bpreshuffle functions pass.

* Split the blockwise pipeline for fp8xint4.

* commit missing files

* opt gemm2 to 2x2 wave

* fix swizzle = false

* update int4 moe with latest input changes.

* update tile size.

* enable pipeline v3.

* fix nswizzle = true

* commit a version for compiler debug.

* Updated transfer_v3r1_gather to support pk_i4_t type.

* Add type_convert support for int4 moe gemm2.

* Remove some value moves between mfma instructions.

* fix int4 moe

* Updated transfer_v3r1_gather to support pk_i4_t type.

* i4 support lds multiple shuffle

* fixed int4 moe tflops calculation.

* Modified CshuffleCShuffleMXdlPerWavePerShuffle to 1 to suit C multiple shuffle

* updated gemm2.

* change int4 moe example names

* fix and format code.

* format.

* format codes.

* update fp8xint4 example tile size.

* add <unordered_map> header

* fixed.

* format.

* Added conditional compilation for int4 -> fp8 conversion kernels

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
Author: Mingtao Gu
Date: 2025-03-10 11:16:44 +08:00
Committed by: GitHub
Parent: c954bd0cfa
Commit: 0db7c8f0b2
19 changed files with 6018 additions and 83 deletions


@@ -11,6 +11,13 @@
namespace ck {

inline __device__ int amd_assembly_and_b32(int a, int b)
{
    int c;
    asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
    return c;
}

inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
    int c;
@@ -32,6 +39,54 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
    return c;
}

inline __device__ float amd_assemble_cvt_f32_i4(int b)
{
    float a;
    asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b));
    return a;
}
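The semantics of `v_cvt_off_f32_i4` can be modeled on the host. A minimal sketch, assuming the instruction's documented lookup table (the low nibble read as a signed 4-bit integer, scaled by 1/16 into [-0.5, 0.4375]); `cvt_off_f32_i4_ref` is a hypothetical helper name, not part of CK:

```cpp
// Host-side reference for v_cvt_off_f32_i4 (a sketch, not the GPU path):
// the low 4 bits of the source are read as a signed 4-bit integer and
// scaled by 1/16, so the result lies in [-0.5, 0.4375].
float cvt_off_f32_i4_ref(int b)
{
    const int nib = b & 0xF;                      // low nibble only
    const int s4  = (nib >= 8) ? nib - 16 : nib;  // sign-extend 4 bits
    return static_cast<float>(s4) / 16.0f;
}
```

All sixteen outputs are exactly representable in binary floating point, so a host model can be compared bit-for-bit against the GPU result.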
// Packs four f32 values into four fp8 values; op_sel:[0, 0, 1] directs the
// second v_cvt_pk_fp8_f32 result into the high half of the destination.
inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
{
    f8x4_t a;
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n"
                 "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
                 : "=v"(a)
                 : "v"(b0), "v"(b1), "v"(b2), "v"(b3));
    return a;
}
// Converts eight packed int4 values (one dword) to eight fp8 values: each
// nibble is converted to f32 with v_cvt_off_f32_i4 (selecting bytes via
// src0_sel, and odd nibbles via a 4-bit right shift), then the floats are
// packed pairwise with v_cvt_pk_fp8_f32.
inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
{
    uint32_t i4x8 = static_cast<uint32_t>(a);
    uint32_t fp8x4_0;
    uint32_t fp8x4_1;
    float tmp_0, tmp_1, tmp_2;
    asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
                 "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
                 : [v_tmp_0] "+v"(tmp_0),
                   [v_tmp_1] "+v"(tmp_1),
                   [v_tmp_2] "+v"(tmp_2),
                   [v_dst_0] "+v"(fp8x4_0),
                   [v_dst_1] "+v"(fp8x4_1),
                   [v_src] "+v"(i4x8)
                 :);
    return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
}
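The element order produced by the assembly above can be sketched on the host. This is an assumption read off the `src0_sel`/`op_sel` fields (and `i4x8_nibble_order` is a hypothetical helper, not CK code): the eight 4-bit fields of the source dword come out interleaved as n0,n4,n1,n5 in the low result dword and n2,n6,n3,n7 in the high one, where n_i is bits [4*i+3 : 4*i]:

```cpp
#include <array>
#include <cstdint>

// Host-side sketch of the nibble order consumed by the i4 -> fp8x8
// conversion: even nibbles come from the unshifted source, odd nibbles
// from the source shifted right by 4, each selected with src0_sel.
std::array<int, 8> i4x8_nibble_order(uint32_t src)
{
    auto nib = [src](int i) {
        const int v = (src >> (4 * i)) & 0xF;
        return v >= 8 ? v - 16 : v;  // sign-extend the 4-bit field
    };
    return {nib(0), nib(4), nib(1), nib(5),   // bytes 0..3 of fp8x4_0
            nib(2), nib(6), nib(3), nib(7)};  // bytes 0..3 of fp8x4_1
}
```

This interleaving is why the packed `pk_i4_t` layout needs no extra shuffle before the pairwise `v_cvt_pk_fp8_f32` packing.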
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)