mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] support split-k a16w4 gemm1 (#3389)
* initial version to support moe gemm1 split-k * add missing args * fix build warning * update reference * for split-k disable bias and weight * remove debug log * fix format * fix div by zero errors * fix cmake config * update * resolve conflicts * remove useless changes * reformat * fix * remove useless changes * fix ci --------- Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
This commit is contained in:
@@ -31,13 +31,14 @@ if(has_supported_gpu)
|
||||
add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)
|
||||
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94")
|
||||
add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
|
||||
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
include(mxgemm/mx_flatmm_instance.cmake)
|
||||
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
|
||||
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
|
||||
@@ -191,13 +191,15 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
std::cout << "Launching kernel " << Kernel::GetName() << "\n"
|
||||
<< "with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
<< "\n"
|
||||
<< "k_batch: " << kargs.k_batch << std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
@@ -471,10 +473,33 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
throw std::runtime_error("Unsupported precision type for gemm2!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_split_k")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_split_k!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_up | gemm2]");
|
||||
"[gemm1_gate_up | gemm1_split_k | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
@@ -69,7 +69,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_up",
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - "
|
||||
"gemm1_gate_up by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
@@ -80,7 +80,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.")
|
||||
.insert("k_batch", "1", "parallism to control splik-k.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -67,9 +67,12 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType =
|
||||
std::conditional_t<kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k, float, PrecActType>;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
@@ -88,6 +91,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
const ck_tile::index_t k_batch = arg_parser.get_int("k_batch");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
@@ -231,14 +235,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()),
|
||||
N / ScaleGranularityN};
|
||||
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>,
|
||||
ck_tile::FlatmmScalePointer<1>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
@@ -250,7 +255,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
k_batch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
||||
@@ -85,8 +85,9 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
c_rslt_host.SetZero();
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
|
||||
invoke_mixed_prec_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
|
||||
@@ -25,14 +25,16 @@ using BF16 = ck_tile::bf16_t;
|
||||
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
|
||||
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
|
||||
|
||||
inline constexpr int ScaleGranularityM = 1;
|
||||
inline constexpr int ScaleGranularityN = 1;
|
||||
inline constexpr int ScaleGranularityK = 32;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
||||
|
||||
template float mx_flatmm_calc<FLATMM_CONFIG,
|
||||
A_DATA_TYPE,
|
||||
|
||||
@@ -105,10 +105,12 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
auto scale_a_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
invoke_mx_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
|
||||
Reference in New Issue
Block a user