Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-12 17:26:00 +00:00)
Instructions for example_gemm_xdl
Run example_gemm_xdl
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: number of times to run the kernel (>1)
./bin/example_gemm_xdl 0 1 5
Instructions for example_gemm_xdl_fp16_streamk_v3
Run example_gemm_xdl_fp16_streamk_v3
arg1: verification (0=no, 1=yes)
arg2: initialization (0=no init, 1=integer value, 2=decimal value)
arg3: time kernel (0=no, 1=yes)
arg4 to arg9: M (multiple of 256), N (multiple of 128), K (multiple of 32), StrideA, StrideB, StrideC
arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)
arg11: Grid_size (-1 for max occupancy)
./bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1
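The eleven positional arguments are easy to misorder, so here is a minimal Python sketch that assembles them in the order listed above. The helper `streamk_args` and its parameter names are hypothetical and not part of composable_kernel. The divisibility checks assume that "256x", "128x" and "32x" mean "a multiple of" that value.

```python
def streamk_args(m, n, k, stride_a, stride_b, stride_c,
                 verify=True, init=2, time_kernel=True,
                 streamk_sel=-1, grid_size=-1):
    """Return arg1..arg11 as strings, in the order documented above.

    Hypothetical helper for illustration; not shipped with the examples.
    """
    # Assumption: M, N, K must be multiples of 256, 128 and 32 respectively.
    for name, value, multiple in (("M", m, 256), ("N", n, 128), ("K", k, 32)):
        if value % multiple != 0:
            raise ValueError(f"{name}={value} is not a multiple of {multiple}")
    # arg10 accepts only the four documented stream-k selections.
    if streamk_sel not in (-1, 0, 1, 2):
        raise ValueError(f"unsupported stream-k selection: {streamk_sel}")
    return [
        str(int(verify)),       # arg1: verification
        str(init),              # arg2: initialization
        str(int(time_kernel)),  # arg3: time kernel
        str(m), str(n), str(k),                          # arg4..arg6
        str(stride_a), str(stride_b), str(stride_c),     # arg7..arg9
        str(streamk_sel),       # arg10: stream-k select
        str(grid_size),         # arg11: grid size
    ]


# Rebuild the argument list of the sample invocation (1-tile SK, max occupancy).
args = streamk_args(3840, 4096, 4096, 4096, 4096, 4096, streamk_sel=1)
print("example_gemm_xdl_fp16_streamk_v3 " + " ".join(args))
```

The resulting list can be passed to `subprocess.run` together with the path to the binary in your build directory.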
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1}
Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal<MNPadding, RRR> BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
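The derived fields in the log can be cross-checked by hand. The Python sketch below recomputes the padded sizes and block counts from the `BlkTile` dimensions, and the throughput figures from the elapsed time. The formulas are assumptions inferred from the printed numbers, not taken from the kernel source: 2·M·N·K floating-point operations, and one read of A and B plus one write of C at 2 bytes per fp16 element.

```python
# Values copied from the sample log above.
M, N, K = 3840, 4096, 4096
TILE_M, TILE_N = 224, 256   # first two entries of "BlkTile: 224x256x64"
ELAPSED_MS = 0.292022
BYTES_PER_ELEM = 2          # fp16 for A, B and C (assumption)


def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


# Padding: round M and N up to the next multiple of the block tile.
m_block = ceil_div(M, TILE_M)
n_block = ceil_div(N, TILE_N)
m_padded = m_block * TILE_M
n_padded = n_block * TILE_N

# Throughput: assumed 2*M*N*K flops and A + B + C memory traffic.
seconds = ELAPSED_MS * 1e-3
tflops = 2 * M * N * K / seconds / 1e12
gb_per_s = BYTES_PER_ELEM * (M * K + K * N + M * N) / seconds / 1e9

print(f"MP:{m_padded}, NP:{n_padded}, MBlock:{m_block}, NBlock:{n_block}")
print(f"{tflops:.2f} TFlops, {gb_per_s:.3f} GB/s")
```

Under these assumptions the script reproduces the logged values MP:4032, NP:4096, MBlock: 18, NBlock: 16, 441.23 TFlops and 330.348 GB/s.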