mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
projects/composablekernel: add SwigluStep support for MoE blockscale (#6118)
## Summary

- add `swiglustep_and_mul` to the composablekernel MoE blockscale activation enum
- implement the corresponding blockscale epilogue path for `SwigluStep`
- keep existing `silu` and `gelu` paths unchanged

## Scope

This PR covers the classic composablekernel blockscale MoE path under `projects/composablekernel`. This is separate from the `ck_tile` / FlatMM path being discussed in ROCm/rocm-libraries#5992.

## Motivation

`Step-3.5-Flash-FP8` uses `SwigluStep` in its MoE MLP path. The dependent AITER change needs native support for this activation in the classic composablekernel MoE blockscale path.

## Validation

- patch is limited to two composablekernel files under `projects/composablekernel`
- existing `silu` / `gelu` paths are unchanged
- dependent AITER runtime validation hit the classic CK 2-stage path with AITER MoE enabled
This commit is contained in:
@@ -28,8 +28,9 @@ namespace ck {
|
||||
|
||||
/// Activation applied in the MoE blockscale GEMM epilogue.
/// Each variant activates the "gate" projection and multiplies it by the
/// "up" projection (gated-MLP style):
///   - gelu_and_mul:       GELU(gate) * up
///   - silu_and_mul:       SiLU(gate) * up
///   - swiglustep_and_mul: SiLU(gate) clamped to <= 7, multiplied by up
///     clamped to [-7, 7] (the `SwigluStep` activation used by
///     Step-3.5-Flash-FP8; see the matching epilogue branch).
/// NOTE(review): the scraped diff showed both pre- and post-change
/// enumerators interleaved; this is the reconstructed post-change enum.
enum Activation
{
    gelu_and_mul       = 0,
    silu_and_mul       = 1,
    swiglustep_and_mul = 2
};
|
||||
|
||||
template <typename ALayout,
|
||||
|
||||
@@ -1592,6 +1592,25 @@ struct GridwiseMoeGemmBlockScale
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
@@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
|
||||
Reference in New Issue
Block a user