mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
support swiglu activation and use rcpf to accelerate silu
This commit is contained in:
@@ -174,9 +174,14 @@ float a16w4_moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
|
||||
using FusedAct =
|
||||
std::conditional_t<MXFP4_Pipeline, ck_tile::moe::Swiglu, ck_tile::moe::MoeSilu>;
|
||||
|
||||
using Kernel = ck_tile::
|
||||
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
|
||||
using Kernel = ck_tile::MoeFlatmmKernel<TilePartitioner,
|
||||
CodegenFlatmmPipeline,
|
||||
GemmEpilogue,
|
||||
moe_kind,
|
||||
FusedAct>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -449,29 +454,6 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_gate_only")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_only!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
@@ -498,7 +480,7 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
|
||||
"[gemm1_gate_up | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -68,17 +68,18 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_only",
|
||||
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
|
||||
"gemm1_gate_only by default.")
|
||||
"gemm1_gate_up",
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
|
||||
"gemm1_gate_up by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("mixed_prec",
|
||||
"bf16xfp4",
|
||||
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
|
||||
@@ -47,21 +47,6 @@ float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct LocalSilu
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(const T& x) const
|
||||
{
|
||||
T y;
|
||||
ck_tile::element_wise::Silu{}(y, x);
|
||||
return y;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename FlatmmConfig,
|
||||
@@ -314,16 +299,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
ck_tile::moe::Swiglu>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
|
||||
@@ -45,21 +45,6 @@ float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct LocalSilu
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(const T& x) const
|
||||
{
|
||||
T y;
|
||||
ck_tile::element_wise::Silu{}(y, x);
|
||||
return y;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
@@ -290,16 +275,15 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
ck_tile::moe::MoeSilu>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
|
||||
Reference in New Issue
Block a user