mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Ck int4 moe develop (#1949)
* Add Gemm fp8xint4 example and kernel, function pass. * Init Gemm_fp8xint4 Bpreshuffle * Added gemm_fp8xint4_Bpreshuffle files, function not checked yet * General fix. * fp8xint4 bpreshuffle function pass * fix. * init b preshuffle dequant in VGPR. * fix bug, function pass. * move b thread dequant copy to blockwise. * fix bug, function now passes. * modified the tile size to 256, 128x128x128. * fixed a bug. * Initial int4 moe, compile pass, function not checked. * fix bug in moe_gemm1.cpp, now function passes. * test expert = 8 and function pass. * Added moe_pk_i4_gemm2, function pass. * Added b preshuffle pipeline v3 support. * fixed merge issue. fp8xint4 and fp8xint4_bpreshuffle function pass. * Split the blockwise pipeline for fp8xint4. * commit missing files * opt gemm2 to 2x2 wave * fix swizzle = false * update int4 moe with latest input changes. * update tile size. * enable pipeline v3. * fix nswizzle = true * commit a version for compiler debug. * Updated transfer_v3r1_gather to support pk_i4_t type. * for int4 moe2 for type_convert support. * remove some values between mfma instructions. * fix int4 moe * Updated transfer_v3r1_gather to support pk_i4_t type. * i4 support lds multiple shuffle * fixed int4 moe tflops calculation. * Modified CshuffleCShuffleMXdlPerWavePerShuffle to 1 to suit C multiple shuffle * updated gemm2. * change int4 moe example names * fix and format code. * format. * format codes. * update fp8xint4 example tile size. * add <unordered_map> header * fixed. * format. * Added conditional compilation for int4 -> fp8 conversion kernels --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -11,6 +11,13 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Bitwise AND of two 32-bit integers emitted as a raw v_and_b32 VALU
// instruction, with all operands pinned to VGPRs via the "v" constraint.
// AMD GCN/CDNA targets only. A single instruction reads both sources before
// writing the destination, so a plain "=v" output (no early-clobber) is safe.
inline __device__ int amd_assembly_and_b32(int a, int b)
{
    int c;
    asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
    return c;
}
|
||||
|
||||
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
|
||||
{
|
||||
int c;
|
||||
@@ -32,6 +39,54 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
|
||||
return c;
|
||||
}
|
||||
|
||||
// Converts the low 4 bits of `b` to a float using the hardware
// v_cvt_off_f32_i4 instruction. NOTE(review): per the GFX ISA, this opcode
// applies an *offset* int4->float mapping (not a plain signed-int convert) —
// confirm the exact value table against the target architecture's ISA manual
// before relying on specific output values.
// Single instruction: sources are read before the destination is written,
// so no early-clobber is needed on the output constraint.
inline __device__ float amd_assemble_cvt_f32_i4(int b)
{
    float a;
    asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b));
    return a;
}
|
||||
|
||||
// Packs four f32 values into four FP8 bytes (f8x4_t) with v_cvt_pk_fp8_f32.
// The first instruction writes the low 16 bits of the result from (b0, b1);
// the second, with op_sel:[0, 0, 1], writes the high 16 bits from (b2, b3)
// while preserving the low half written by the first instruction.
// NOTE(review): despite the name, this converts f32 -> fp8; the name is kept
// unchanged for source compatibility with existing callers.
inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
{
    f8x4_t a;
    // "=&v" (early-clobber) is required here: %0 is written by the first
    // instruction BEFORE %3/%4 are consumed by the second one. With a plain
    // "=v" constraint the compiler may allocate the output register to
    // overlap an input register, corrupting b2/b3 before they are read.
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n"
                 "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
                 : "=&v"(a)
                 : "v"(b0), "v"(b1), "v"(b2), "v"(b3));
    return a;
}
|
||||
|
||||
// Converts eight packed 4-bit signed integers (one per nibble of `a`) into
// eight FP8 values, returned as an f8x8_t.
//
// Strategy: v_cvt_off_f32_i4 converts the low 4 bits of its (byte-selected)
// source, so the low nibble of each byte is converted directly via
// src0_sel:BYTE_0..BYTE_3, while the high nibbles are exposed by shifting the
// whole word right by 4 (v_lshrrev_b32) and repeating the per-byte pass.
// v_cvt_pk_fp8_f32 then packs two f32 results into 16 bits of a destination;
// op_sel:[0, 0, 1] targets the upper half-word while preserving the lower.
inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
{
    uint32_t i4x8 = static_cast<uint32_t>(a);
    uint32_t fp8x4_0; // low 4 fp8 results (written entirely inside the asm)
    uint32_t fp8x4_1; // high 4 fp8 results (written entirely inside the asm)
    // Scratch VGPRs for the asm block. They are intentionally consumed
    // uninitialized: every tmp is written before it is read inside the asm,
    // and the "+v" (in-out) constraints below force each operand into its own
    // register so outputs never alias v_src.
    float tmp_0, tmp_1, tmp_2;

    asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
                 // Shift by 4 so each byte's HIGH nibble becomes the low
                 // nibble, then repeat the four conversions into the upper
                 // half-words of the two destinations.
                 "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
                 "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
                 "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
                 "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
                 : [v_tmp_0] "+v"(tmp_0),
                   [v_tmp_1] "+v"(tmp_1),
                   [v_tmp_2] "+v"(tmp_2),
                   [v_dst_0] "+v"(fp8x4_0),
                   [v_dst_1] "+v"(fp8x4_1),
                   [v_src] "+v"(i4x8)
                 :);

    // Assemble the 64-bit result: low word first, high word in the top bits.
    return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
|
||||
Reference in New Issue
Block a user