projects/composablekernel: add SwigluStep support for MoE blockscale (#6118)

## Summary
- add `swiglustep_and_mul` to the composablekernel MoE blockscale
activation enum
- implement the corresponding blockscale epilogue path for `SwigluStep`
- keep existing `silu` and `gelu` paths unchanged

## Scope
This PR covers the classic composablekernel blockscale MoE path under
`projects/composablekernel`.

This is separate from the `ck_tile` / FlatMM path being discussed in
ROCm/rocm-libraries#5992.

## Motivation
`Step-3.5-Flash-FP8` uses `SwigluStep` in its MoE MLP path. The
dependent AITER change needs native support for this activation in the
classic composablekernel MoE blockscale path.

## Validation
- patch is limited to two composablekernel files under
`projects/composablekernel`
- existing `silu` / `gelu` paths are unchanged
- dependent AITER runtime validation exercised the classic CK 2-stage
path with AITER MoE enabled
This commit is contained in:
金黄色葡萄球君君
2026-04-21 15:24:48 +08:00
committed by GitHub
parent b367e98358
commit c9e8acc56a
2 changed files with 41 additions and 2 deletions

View File

@@ -28,8 +28,9 @@ namespace ck {
// Selector for the MoE gated-MLP epilogue: chooses how the gate
// projection is combined with the up projection in the blockscale path.
// NOTE: values are part of the external interface (kernels are
// instantiated by integer value), so existing entries must keep their
// numeric IDs; new activations are appended with the next value.
enum Activation
{
    gelu_and_mul       = 0, // out = gelu(gate) * up
    silu_and_mul       = 1, // out = silu(gate) * up
    swiglustep_and_mul = 2  // out = clamp(silu(gate)) * clamp(up), step limit 7.0
};
template <typename ALayout,

View File

@@ -1592,6 +1592,25 @@ struct GridwiseMoeGemmBlockScale
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
// SwigluStep: silu(gate) * up, with both operands clamped to a fixed
// step limit of 7.0f before the final multiply (activation used by
// Step-3.5-Flash-FP8's MoE MLP path, per the commit description).
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
// Fold the routed top-k expert weight into both operands when the
// kernel is instantiated with routed-weight multiplication.
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weight;
up = up * topk_weight;
}
// Packed-int4 B operand: rescale by 16 — presumably undoing a /16
// factor from the pk_i4_t quantization convention; TODO confirm
// against the dequant path.
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
// Clamp gate from above only and up to [-7, 7]. NOTE(review): the
// asymmetry looks intentional — SiLU's minimum is ~-0.278, so a
// lower clamp on gate would never bind.
gate = gate < 7.0f ? gate : 7.0f;
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
@@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
// SwigluStep for this pipeline variant of GridwiseMoeGemmBlockScale:
// silu(gate) * up with both operands clamped to a fixed step limit of
// 7.0f before the final multiply.
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
// Fold the routed top-k expert weight into both operands when the
// kernel is instantiated with routed-weight multiplication.
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weight;
up = up * topk_weight;
}
// Packed-int4 B operand: rescale by 16 — presumably undoing a /16
// factor from the pk_i4_t quantization convention; TODO confirm
// against the dequant path.
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
// Clamp gate from above only and up to [-7, 7]. NOTE(review): the
// asymmetry looks intentional — SiLU's minimum is ~-0.278, so a
// lower clamp on gate would never bind.
gate = gate < 7.0f ? gate : 7.0f;
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];