[rocm-libraries] ROCm/rocm-libraries#8944 (commit 7be2dbb)

feat(ck): Add swiglu_oai (OAI SwiGLU) activation to XDL
 2-stage MoE epilogue. (#8944)

## Motivation

Enable the OAI-form SwiGLU activation (`swiglu_oai`, `gate *
sigmoid(1.702 * gate) * (up + 1)`, gpt-oss style) in the Composable
Kernel XDL 2-stage MoE path. The MoE gridwise kernel epilogue currently
supports only silu/gelu; this adds swiglu_oai so OAI-style MoE models
can use this path.

JIRA ID : ROCM-27213

## Technical Details

- `gridwise_gemm_xdl_cshuffle_common.hpp`: add
`Activation::swiglu_oai_and_mul = 3` and the shared `ck::swiglu_oai`
helper (used by the kernel epilogue and the host unit test).
- `gridwise_moe_gemm.hpp`: add the `apply_swiglu_oai_activation` helper
(`gate * sigmoid(1.702 * gate) * (up + 1)`, clamp `gate <= 7` and `up in
[-7, 7]`, OAI/gpt-oss form) and wire it into all 4 epilogue paths (quant
+ non-quant x `Run` / `Run_2Lds`).
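- The activation is applied in fp32 in the epilogue and is orthogonal to
the GEMM compute (MFMA/tile/pipeline untouched) and to quantization
(existing per-token dequant reused). Only the non-blockscale gridwise
kernel gains the new epilogue branch; the blockscale kernel
(`GridwiseMoeGemmBlockScale`) only gets a `static_assert` that rejects
`swiglu_oai_and_mul`.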
- Consumed by aiter via ROCm/aiter#3886 (dispatch + codegen);
review/merge together.

## Test Plan

Validate the new epilogue branch against a torch fp32 OAI-SwiGLU
reference through the aiter per-token fp8 MoE path (op-isolate on gfx942
/ MI308X).

## Test Result

cos_sim = 0.999993 vs the torch fp32 OAI-SwiGLU reference; no NaN.
Confirmed the per-token fp8 path dispatches to this `GridwiseMoeGemm`
kernel (rocprofv3) and runs the swiglu_oai epilogue branch.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Linjun-AMD
2026-07-01 12:36:31 +00:00
committed by assistant-librarian[bot]
parent a65244b86c
commit f4e6fad973
6 changed files with 189 additions and 1 deletions

View File

@@ -30,9 +30,28 @@ enum Activation
{
gelu_and_mul = 0,
silu_and_mul = 1,
swiglustep_and_mul = 2
swiglustep_and_mul = 2,
swiglu_oai_and_mul = 3
};
// OAI / gpt-oss SwiGLU activation. Returns g * sigmoid(alpha * g) * (u + 1), where
//   g = min(gate, limit)             -- gate is only upper-bounded, never lower-clamped
//   u = min(max(up, -limit), limit)  -- up is clamped symmetrically to [-limit, limit]
// The defaults (limit = 7.0, alpha = 1.702) follow gpt-oss.
//
// This is the same math as ck_tile::moe::Swiglu, except that the sigmoid is evaluated with a
// plain fp32 division instead of __builtin_amdgcn_rcpf, so the two agree to ~1e-6 rather than
// bit-for-bit.
//
// Shared by the MoE kernel epilogue (swiglu_oai_and_mul) and its host unit test so both
// exercise one implementation.
//
// Deliberately not constexpr: math::exp is not constexpr, so this function could never be
// evaluated in a constant expression. `inline` keeps the header definition ODR-safe; every
// caller invokes it at runtime.
__host__ __device__ inline float
swiglu_oai(float gate, float up, float limit = 7.0f, float alpha = 1.702f)
{
    const float clamped_gate = math::min(gate, limit);
    const float clamped_up   = math::min(math::max(up, -limit), limit);
    // sigmoid(alpha * clamped_gate), written as 1 / (1 + exp(-alpha * clamped_gate)).
    const float gate_sigmoid = 1.0f / (1.0f + math::exp(alpha * -clamped_gate));
    return clamped_gate * gate_sigmoid * (clamped_up + 1.0f);
}
template <typename ALayout,
typename BLayout,
typename ELayout,

View File

@@ -297,6 +297,25 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp);
return gate * up;
}
// Parameters of the swiglu_oai_and_mul activation (gpt-oss / OAI form).
// kSwiGluOaiLimit (L): gate is clamped to <= L and up is clamped to [-L, L]; fixed at 7.0.
// kSwiGluOaiAlpha: sigmoid slope, i.e. sigmoid(alpha * gate); fixed at the gpt-oss default 1.702.
static constexpr float kSwiGluOaiLimit = 7.0f;
static constexpr float kSwiGluOaiAlpha = 1.702f;
// Applies the OAI SwiGLU activation gate * sigmoid(alpha * gate) * (up + 1) with the
// pre-activation clamp (gate upper-bounded to kSwiGluOaiLimit, up clamped symmetrically to
// [-kSwiGluOaiLimit, kSwiGluOaiLimit]) and alpha = kSwiGluOaiAlpha.
//
// Same math as ck_tile::moe::Swiglu in ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp, but
// the sigmoid here uses a plain fp32 division (not __builtin_amdgcn_rcpf), so results match
// to ~1e-6 rather than bit-exact. Distinct from swiglustep (no +1, no alpha).
//
// Called by every swiglu_oai_and_mul epilogue branch in this kernel.
//
// Not constexpr on purpose: it forwards to ck::swiglu_oai, which is a non-constexpr function
// (it calls math::exp), so this wrapper could never be evaluated in a constant expression.
//
// @param gate gate-projection value (fp32), before clamping
// @param up   up-projection value (fp32), before clamping
// @return     activated fp32 value
__host__ __device__ static float apply_swiglu_oai_activation(float gate, float up)
{
    // Delegate to the shared ck::swiglu_oai (defined in gridwise_gemm_xdl_cshuffle_common.hpp)
    // so the kernel epilogue and the host unit test exercise a single implementation.
    return ck::swiglu_oai(gate, up, kSwiGluOaiLimit, kSwiGluOaiAlpha);
}
using mfma_selector =
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>;
@@ -1506,6 +1525,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
// OAI SwiGLU branch of the scaled epilogue: scale the gate and up accumulators,
// optionally apply the routed weight, then activate in fp32.
// NOTE(review): scale_up is read from p_scale_b at an extra offset of problem.N;
// presumably the up-projection scales are stored after the gate scales — confirm
// against the scale_b load outside this hunk.
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
// The routed weight multiplies gate and up BEFORE the clamp/activation, so it is
// not a plain post-scale of the activated output.
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
// pk_i4_t B data: both values are multiplied by 16 before activation.
// NOTE(review): the rationale for the x16 factor is not visible in this hunk — confirm.
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
// Clamp + gate * sigmoid(alpha * gate) * (up + 1); result stored as fp32.
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
@@ -1579,6 +1618,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
// OAI SwiGLU branch of the unscaled epilogue: the accumulators are used as-is
// (no scale_a / scale_b factors are applied in this branch).
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
// The routed weight multiplies gate and up BEFORE the clamp/activation, so it is
// not a plain post-scale of the activated output.
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
// Clamp + gate * sigmoid(alpha * gate) * (up + 1); result stored as fp32.
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
@@ -2011,6 +2062,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
// OAI SwiGLU branch of the scaled epilogue: scale the gate and up accumulators,
// optionally apply the routed weight, then activate in fp32.
// NOTE(review): scale_up is read from p_scale_b at an extra offset of problem.N;
// presumably the up-projection scales are stored after the gate scales — confirm
// against the scale_b load outside this hunk.
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
// The routed weight multiplies gate and up BEFORE the clamp/activation, so it is
// not a plain post-scale of the activated output.
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
// pk_i4_t B data: both values are multiplied by 16 before activation.
// NOTE(review): the rationale for the x16 factor is not visible in this hunk — confirm.
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
// Clamp + gate * sigmoid(alpha * gate) * (up + 1); result stored as fp32.
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{
@@ -2084,6 +2155,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
else if constexpr(ActivationOperation ==
Activation::swiglu_oai_and_mul)
{
// OAI SwiGLU branch of the unscaled epilogue: the accumulators are used as-is
// (no scale_a / scale_b factors are applied in this branch).
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
// The routed weight multiplies gate and up BEFORE the clamp/activation, so it is
// not a plain post-scale of the activated output.
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
// Clamp + gate * sigmoid(alpha * gate) * (up + 1); result stored as fp32.
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
}
}
else
{

View File

@@ -1140,6 +1140,9 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
static_assert(ActivationOperation != Activation::swiglu_oai_and_mul,
"gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the "
"non-blockscale gridwise_moe_gemm.");
#if defined(__gfx942__) || defined(__gfx950__)
constexpr auto b_coherence_flag = NonTemporalLoadB
? AmdBufferCoherenceEnum::WAVE_NT1
@@ -1694,6 +1697,9 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
static_assert(ActivationOperation != Activation::swiglu_oai_and_mul,
"gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the "
"non-blockscale gridwise_moe_gemm.");
#if defined(__gfx942__) || defined(__gfx950__)
constexpr auto b_coherence_flag = NonTemporalLoadB
? AmdBufferCoherenceEnum::WAVE_NT1

View File

@@ -350,3 +350,4 @@ add_subdirectory(synchronization)
add_subdirectory(gpu_reference)
add_subdirectory(util)
add_subdirectory(gpu_verification)
add_subdirectory(swiglu_oai_activation)

View File

@@ -0,0 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Host-side GoogleTest target for the ck::swiglu_oai activation helper.
add_gtest_executable(test_swiglu_oai_activation test_swiglu_oai_activation.cpp)
# Link only when the target exists.
# NOTE(review): `result` is presumably set by add_gtest_executable (0 when the target was
# created); the macro is not visible here — confirm.
if(result EQUAL 0)
target_link_libraries(test_swiglu_oai_activation PRIVATE utility)
endif()

View File

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <algorithm>
#include <cmath>
#include "gtest/gtest.h"
// ck::swiglu_oai is the single source of truth shared by the XDL 2-stage MoE epilogue
// (Activation::swiglu_oai_and_mul) and this host unit test.
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp"
namespace {
// Double-precision reference for the OAI / gpt-oss SwiGLU activation:
//     g * sigmoid(alpha * g) * (u + 1)
// where g is gate capped at limit (no lower bound) and u is up clamped to [-limit, limit].
// It is evaluated entirely in fp64 so the fp32 implementation under test is compared
// against an independent computation rather than a copy of itself.
double ref_swiglu_oai(double gate, double up, double limit = 7.0, double alpha = 1.702)
{
    const double capped_gate = std::min(gate, limit);
    const double clamped_up  = std::min(std::max(up, -limit), limit);
    const double sigmoid     = 1.0 / (1.0 + std::exp(alpha * -capped_gate));
    return capped_gate * sigmoid * (clamped_up + 1.0);
}
constexpr float kTol = 1e-3f;
} // namespace
// With every input inside the clamp range, the fp32 implementation must agree with the
// fp64 reference to within kTol.
TEST(SwigluOai, MatchesReferenceInRange)
{
    struct Sample
    {
        float gate;
        float up;
    };
    const Sample samples[] = {
        {0.f, 0.f}, {1.f, 2.f}, {-1.5f, 0.5f}, {3.f, -2.f}, {-4.f, -3.f}};

    for(const Sample& sample : samples)
    {
        const float got = ck::swiglu_oai(sample.gate, sample.up);
        const float ref = static_cast<float>(ref_swiglu_oai(sample.gate, sample.up));
        EXPECT_NEAR(got, ref, kTol) << "gate=" << sample.gate << " up=" << sample.up;
    }
}
// gate is upper-bounded to limit (7 by default); it is NOT lower-clamped.
TEST(SwigluOai, GateUpperClamp)
{
    // gate above the limit behaves exactly like gate == limit.
    EXPECT_NEAR(ck::swiglu_oai(100.f, 1.f), static_cast<float>(ref_swiglu_oai(7.0, 1.0)), kTol);
    // gate >= limit all saturate to the same value.
    EXPECT_NEAR(ck::swiglu_oai(7.f, 1.f), ck::swiglu_oai(100.f, 1.f), kTol);
    // large negative gate passes through unclamped.
    EXPECT_NEAR(ck::swiglu_oai(-50.f, 1.f), static_cast<float>(ref_swiglu_oai(-50.0, 1.0)), kTol);
    // The check above cannot detect an erroneous lower clamp: with limit = 7 the output is
    // below 1e-4 in magnitude for any gate <= -7, which is inside kTol. Repeat it with a small
    // explicit limit, where clamping gate = -3 up to -1 would move the result from about
    // -0.036 to about -0.31.
    EXPECT_NEAR(ck::swiglu_oai(-3.f, 1.f, 1.0f),
                static_cast<float>(ref_swiglu_oai(-3.0, 1.0, 1.0)),
                kTol);
}
// up is symmetric-clamped to [-7, 7] with the default limit.
TEST(SwigluOai, UpSymmetricClamp)
{
// up far above the limit must behave like up == +7.
EXPECT_NEAR(ck::swiglu_oai(1.f, 100.f), static_cast<float>(ref_swiglu_oai(1.0, 7.0)), kTol);
// up far below -limit must behave like up == -7.
EXPECT_NEAR(ck::swiglu_oai(1.f, -100.f), static_cast<float>(ref_swiglu_oai(1.0, -7.0)), kTol);
}
// The "+1" shift on up is part of the OAI form: up == -1 zeroes the output exactly.
// up == -1 lies inside the default clamp range, so (up + 1) is exactly 0 and the whole
// product is zero regardless of gate.
TEST(SwigluOai, UpPlusOneShiftZeroesOutput)
{
EXPECT_FLOAT_EQ(ck::swiglu_oai(2.f, -1.f), 0.f);
// NOTE(review): with a negative gate the product is -0.0f; EXPECT_FLOAT_EQ is expected to
// treat -0.0f and 0.f as equal — confirm.
EXPECT_FLOAT_EQ(ck::swiglu_oai(-3.f, -1.f), 0.f);
}
// alpha defaults to 1.702 (gpt-oss); passing it explicitly must not change the result.
// The explicit call also spells out limit = 7.0f, so both default arguments are covered.
TEST(SwigluOai, DefaultAlphaMatchesExplicit)
{
EXPECT_FLOAT_EQ(ck::swiglu_oai(2.f, 3.f), ck::swiglu_oai(2.f, 3.f, 7.0f, 1.702f));
}