[CK_TILE] support split-k a16w4 gemm1 (#3389)

* initial version to support moe gemm1 split-k

* add missing args

* fix build warning

* update reference

* for split-k disable bias and weight

* remove debug log

* fix format

* fix div by zero errors

* fix cmake config

* update

* resolve conflicts

* remove useless changes

* reformat

* fix

* remove useless changes

* fix ci

---------

Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
This commit is contained in:
yadaish
2025-12-29 23:05:35 +08:00
committed by GitHub
parent a0acc83a72
commit dae85ead64
11 changed files with 136 additions and 78 deletions

View File

@@ -31,13 +31,14 @@ if(has_supported_gpu)
add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
if (GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94")
add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
endif()
if (GPU_TARGETS MATCHES "gfx95")
include(mxgemm/mx_flatmm_instance.cmake)
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")

View File

@@ -8,7 +8,7 @@
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;

View File

@@ -191,13 +191,15 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
std::cout << "Launching kernel " << Kernel::GetName() << "\n"
<< "with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
<< "\n"
<< "k_batch: " << kargs.k_batch << std::endl;
}
if(s.flush_cache_)
@@ -471,10 +473,33 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
throw std::runtime_error("Unsupported precision type for gemm2!");
}
}
else if(gemm_kind == "gemm1_split_k")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_split_k!");
}
}
else
{
throw std::runtime_error("Unrecognized gemm_kind parameter, only accept value "
"[gemm1_gate_up | gemm2]");
"[gemm1_gate_up | gemm1_split_k | gemm2]");
}
}
else

View File

@@ -13,7 +13,7 @@
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
@@ -69,7 +69,7 @@ auto create_args(int argc, char* argv[])
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_up",
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
"Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - "
"gemm1_gate_up by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
@@ -80,7 +80,8 @@ auto create_args(int argc, char* argv[])
.insert("warp_tile",
"0",
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
.insert("repeat", "10", "number of iterations to benchmark the kernel.")
.insert("k_batch", "1", "parallelism to control split-k.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -67,9 +67,12 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
return -1;
};
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType = PrecActType;
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType =
std::conditional_t<kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k, float, PrecActType>;
using AccDataType = float;
using ScaleType = ck_tile::e8m0_t;
@@ -88,6 +91,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
const ck_tile::index_t k_batch = arg_parser.get_int("k_batch");
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
@@ -231,14 +235,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
auto scale_b_shuffle_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()),
N / ScaleGranularityN};
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
ck_tile::FlatmmScalePointer<-1>,
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>,
ck_tile::FlatmmScalePointer<1>>;
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
p_sorted_expert_weight_dev,
@@ -250,7 +255,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
num_tokens,
experts,
topk,
1, // k_batch
k_batch, // k_batch
M,
N,
K,

View File

@@ -85,8 +85,9 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_rslt_host.SetZero();
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
auto scale_b_dev_ptr =
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
invoke_mixed_prec_flatmm<FlatmmConfig,
ADataType,

View File

@@ -25,14 +25,16 @@ using BF16 = ck_tile::bf16_t;
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
using ScaleType = ck_tile::e8m0_t;
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
inline constexpr int ScaleGranularityM = 1;
inline constexpr int ScaleGranularityN = 1;
inline constexpr int ScaleGranularityK = 32;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
template float mx_flatmm_calc<FLATMM_CONFIG,
A_DATA_TYPE,

View File

@@ -105,10 +105,12 @@ int run_mx_flatmm_with_layouts(int argc,
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
auto scale_a_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
invoke_mx_flatmm<FlatmmConfig,
ADataType,