support swiglu activaion and use rcpf to accelerate silu

This commit is contained in:
Feng Shijie
2025-08-26 12:32:29 +00:00
parent d05eed931d
commit 65b702454c
8 changed files with 376 additions and 350 deletions

View File

@@ -174,9 +174,14 @@ float a16w4_moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
MXFP4_Pipeline,
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
using FusedAct =
std::conditional_t<MXFP4_Pipeline, ck_tile::moe::Swiglu, ck_tile::moe::MoeSilu>;
using Kernel = ck_tile::
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
using Kernel = ck_tile::MoeFlatmmKernel<TilePartitioner,
CodegenFlatmmPipeline,
GemmEpilogue,
moe_kind,
FusedAct>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -449,29 +454,6 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm1_gate_only")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_only!");
}
}
else if(gemm_kind == "gemm2")
{
if(mixed_prec == "fp16xfp4")
@@ -498,7 +480,7 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
else
{
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
"[gemm1_gate_up | gemm2]");
}
}
else

View File

@@ -68,17 +68,18 @@ auto create_args(int argc, char* argv[])
.insert("b_layout", "C", "B tensor data layout - Col by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_only",
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
"gemm1_gate_only by default.")
"gemm1_gate_up",
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
"gemm1_gate_up by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("mixed_prec",
"bf16xfp4",
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
.insert("init", "0", "0:random, 1:constant(1)")
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
.insert("warp_tile",
"0",
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
bool result = arg_parser.parse(argc, argv);

View File

@@ -47,21 +47,6 @@ float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
return ave_time;
}
namespace {
struct LocalSilu
{
template <typename T>
CK_TILE_HOST_DEVICE T operator()(const T& x) const
{
T y;
ck_tile::element_wise::Silu{}(y, x);
return y;
};
};
} // namespace
template <typename PrecActType,
typename PrecWeightType,
typename FlatmmConfig,
@@ -314,16 +299,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
c_m_n_ref_buf->SetZero();
ck_tile::reference_moe_gemm_gpu<
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
ck_tile::reference_moe_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
ck_tile::moe::Swiglu>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,

View File

@@ -45,21 +45,6 @@ float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
return ave_time;
}
namespace {
struct LocalSilu
{
template <typename T>
CK_TILE_HOST_DEVICE T operator()(const T& x) const
{
T y;
ck_tile::element_wise::Silu{}(y, x);
return y;
};
};
} // namespace
template <typename PrecType,
typename FlatmmConfig,
ck_tile::MoeFlatmmKind kind,
@@ -290,16 +275,15 @@ int run_moe_gemm_example_with_layouts(int argc,
c_m_n_ref_buf->SetZero();
ck_tile::reference_moe_gemm_gpu<
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
ck_tile::reference_moe_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
ck_tile::moe::MoeSilu>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,