mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
[rocm-libraries] ROCm/rocm-libraries#8944 (commit 7be2dbb)
feat(ck): Add swiglu_oai (OAI SwiGLU) activation to XDL 2-stage MoE epilogue. (#8944) ## Motivation Enable the OAI-form SwiGLU activation (`swiglu_oai`, `gate * sigmoid(1.702 * gate) * (up + 1)`, gpt-oss style) in the Composable Kernel XDL 2-stage MoE path. The MoE gridwise kernel epilogue currently supports only silu/gelu; this adds swiglu_oai so OAI-style MoE models can use this path. JIRA ID : ROCM-27213 ## Technical Details - `gridwise_gemm_xdl_cshuffle_common.hpp`: add `Activation::swiglu_oai_and_mul = 3`. - `gridwise_moe_gemm.hpp`: add the `apply_swiglu_oai_activation` helper (`gate * sigmoid(1.702 * gate) * (up + 1)`, clamp `gate <= 7` and `up in [-7, 7]`, OAI/gpt-oss form) and wire it into all 4 epilogue paths (quant + non-quant x `Run` / `Run_2Lds`). - The activation is applied in fp32 in the epilogue and is orthogonal to the GEMM compute (MFMA/tile/pipeline untouched) and to quantization (existing per-token dequant reused). Only the non-blockscale gridwise kernel is changed. - Consumed by aiter via ROCm/aiter#3886 (dispatch + codegen); review/merge together. ## Test Plan Validate the new epilogue branch against a torch fp32 OAI-SwiGLU reference through the aiter per-token fp8 MoE path (op-isolate on gfx942 / MI308X). ## Test Result cos_sim = 0.999993 vs the torch fp32 OAI-SwiGLU reference; no NaN. Confirmed the per-token fp8 path dispatches to this `GridwiseMoeGemm` kernel (rocprofv3) and runs the swiglu_oai epilogue branch. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a65244b86c
commit
f4e6fad973
@@ -30,9 +30,28 @@ enum Activation
|
||||
{
|
||||
gelu_and_mul = 0,
|
||||
silu_and_mul = 1,
|
||||
swiglustep_and_mul = 2
|
||||
swiglustep_and_mul = 2,
|
||||
swiglu_oai_and_mul = 3
|
||||
};
|
||||
|
||||
// OAI / gpt-oss SwiGLU activation: gate * sigmoid(alpha * gate) * (up + 1), with a
|
||||
// pre-activation clamp (gate upper-bounded to limit, up symmetric in [-limit, limit]).
|
||||
// Defaults limit = 7.0, alpha = 1.702 per gpt-oss. Same math as ck_tile::moe::Swiglu, but
|
||||
// the sigmoid here uses a plain fp32 division (not __builtin_amdgcn_rcpf), so results match
|
||||
// to ~1e-6 rather than bit-exact. Single source of truth shared by the MoE kernel epilogue
|
||||
// (swiglu_oai_and_mul) and its host unit test.
|
||||
// Not constexpr: it always calls math::exp (non-constexpr), so it can never be evaluated in a
|
||||
// constant expression. inline keeps it ODR-safe in this header. The kernel epilogue and the host
|
||||
// unit test only call it at runtime.
|
||||
__host__ __device__ inline float
|
||||
swiglu_oai(float gate, float up, float limit = 7.0f, float alpha = 1.702f)
|
||||
{
|
||||
gate = math::min(gate, limit); // gate <= limit
|
||||
up = math::min(math::max(up, -limit), limit); // up in [-limit, limit]
|
||||
const float sig = 1.0f / (1.0f + math::exp(alpha * -gate)); // sigmoid(alpha * gate)
|
||||
return gate * sig * (up + 1.0f); // OAI form
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
|
||||
@@ -297,6 +297,25 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp);
|
||||
return gate * up;
|
||||
}
|
||||
|
||||
// Clamp limit for swiglu_oai_and_mul (gpt-oss / OAI form): gate clamped to <= L,
|
||||
// up clamped to [-L, L]; L hardcoded to 7.0. alpha = 1.702 per gpt-oss default.
|
||||
static constexpr float kSwiGluOaiLimit = 7.0f;
|
||||
static constexpr float kSwiGluOaiAlpha = 1.702f;
|
||||
|
||||
// Helper: apply OAI SwiGLU activation gate*sigmoid(alpha*gate)*(up+1) with pre-activation
|
||||
// clamp (gate upper-bounded, up symmetric). Same math as ck_tile::moe::Swiglu in
|
||||
// ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp, but the sigmoid here uses a plain fp32
|
||||
// division (not __builtin_amdgcn_rcpf), so results match to ~1e-6 rather than bit-exact.
|
||||
// Distinct from swiglustep (no +1, no alpha).
|
||||
// Used by all four swiglu_oai_and_mul epilogue paths (quant/non-quant x pipeline-A/B).
|
||||
__host__ __device__ static constexpr float apply_swiglu_oai_activation(float gate, float up)
|
||||
{
|
||||
// Delegate to the shared ck::swiglu_oai (defined in gridwise_gemm_xdl_cshuffle_common.hpp)
|
||||
// so the kernel epilogue and the host unit test exercise a single implementation.
|
||||
return ck::swiglu_oai(gate, up, kSwiGluOaiLimit, kSwiGluOaiAlpha);
|
||||
}
|
||||
|
||||
using mfma_selector =
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>;
|
||||
|
||||
@@ -1506,6 +1525,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglu_oai_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1579,6 +1618,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglu_oai_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -2011,6 +2062,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglu_oai_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -2084,6 +2155,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglu_oai_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -1140,6 +1140,9 @@ struct GridwiseMoeGemmBlockScale
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
static_assert(ActivationOperation != Activation::swiglu_oai_and_mul,
|
||||
"gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the "
|
||||
"non-blockscale gridwise_moe_gemm.");
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
constexpr auto b_coherence_flag = NonTemporalLoadB
|
||||
? AmdBufferCoherenceEnum::WAVE_NT1
|
||||
@@ -1694,6 +1697,9 @@ struct GridwiseMoeGemmBlockScale
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
static_assert(ActivationOperation != Activation::swiglu_oai_and_mul,
|
||||
"gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the "
|
||||
"non-blockscale gridwise_moe_gemm.");
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
constexpr auto b_coherence_flag = NonTemporalLoadB
|
||||
? AmdBufferCoherenceEnum::WAVE_NT1
|
||||
|
||||
@@ -350,3 +350,4 @@ add_subdirectory(synchronization)
|
||||
add_subdirectory(gpu_reference)
|
||||
add_subdirectory(util)
|
||||
add_subdirectory(gpu_verification)
|
||||
add_subdirectory(swiglu_oai_activation)
|
||||
|
||||
7
test/swiglu_oai_activation/CMakeLists.txt
Normal file
7
test/swiglu_oai_activation/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(test_swiglu_oai_activation test_swiglu_oai_activation.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_swiglu_oai_activation PRIVATE utility)
|
||||
endif()
|
||||
72
test/swiglu_oai_activation/test_swiglu_oai_activation.cpp
Normal file
72
test/swiglu_oai_activation/test_swiglu_oai_activation.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// ck::swiglu_oai is the single source of truth shared by the XDL 2-stage MoE epilogue
|
||||
// (Activation::swiglu_oai_and_mul) and this host unit test.
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
// Independent fp64 reference for the OAI / gpt-oss SwiGLU activation:
|
||||
// gate * sigmoid(alpha * gate) * (up + 1)
|
||||
// with gate clamped to <= limit and up clamped to [-limit, limit]. Computed in double so
|
||||
// we are not comparing the fp32 implementation under test against a copy of itself.
|
||||
double ref_swiglu_oai(double gate, double up, double limit = 7.0, double alpha = 1.702)
|
||||
{
|
||||
gate = std::min(gate, limit);
|
||||
up = std::min(std::max(up, -limit), limit);
|
||||
const double s = 1.0 / (1.0 + std::exp(alpha * -gate));
|
||||
return gate * s * (up + 1.0);
|
||||
}
|
||||
|
||||
constexpr float kTol = 1e-3f;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Values match the reference when no clamping is active.
|
||||
TEST(SwigluOai, MatchesReferenceInRange)
|
||||
{
|
||||
const float pts[][2] = {{0.f, 0.f}, {1.f, 2.f}, {-1.5f, 0.5f}, {3.f, -2.f}, {-4.f, -3.f}};
|
||||
for(const auto& p : pts)
|
||||
{
|
||||
const float got = ck::swiglu_oai(p[0], p[1]);
|
||||
const float ref = static_cast<float>(ref_swiglu_oai(p[0], p[1]));
|
||||
EXPECT_NEAR(got, ref, kTol) << "gate=" << p[0] << " up=" << p[1];
|
||||
}
|
||||
}
|
||||
|
||||
// gate is upper-bounded to limit (7); it is NOT lower-clamped.
|
||||
TEST(SwigluOai, GateUpperClamp)
|
||||
{
|
||||
EXPECT_NEAR(ck::swiglu_oai(100.f, 1.f), static_cast<float>(ref_swiglu_oai(7.0, 1.0)), kTol);
|
||||
// gate >= limit all saturate to the same value.
|
||||
EXPECT_NEAR(ck::swiglu_oai(7.f, 1.f), ck::swiglu_oai(100.f, 1.f), kTol);
|
||||
// large negative gate passes through unclamped.
|
||||
EXPECT_NEAR(ck::swiglu_oai(-50.f, 1.f), static_cast<float>(ref_swiglu_oai(-50.0, 1.0)), kTol);
|
||||
}
|
||||
|
||||
// up is symmetric-clamped to [-7, 7].
|
||||
TEST(SwigluOai, UpSymmetricClamp)
|
||||
{
|
||||
EXPECT_NEAR(ck::swiglu_oai(1.f, 100.f), static_cast<float>(ref_swiglu_oai(1.0, 7.0)), kTol);
|
||||
EXPECT_NEAR(ck::swiglu_oai(1.f, -100.f), static_cast<float>(ref_swiglu_oai(1.0, -7.0)), kTol);
|
||||
}
|
||||
|
||||
// The "+1" shift on up is part of the OAI form: up == -1 zeroes the output exactly.
|
||||
TEST(SwigluOai, UpPlusOneShiftZeroesOutput)
|
||||
{
|
||||
EXPECT_FLOAT_EQ(ck::swiglu_oai(2.f, -1.f), 0.f);
|
||||
EXPECT_FLOAT_EQ(ck::swiglu_oai(-3.f, -1.f), 0.f);
|
||||
}
|
||||
|
||||
// alpha defaults to 1.702 (gpt-oss); passing it explicitly must not change the result.
|
||||
TEST(SwigluOai, DefaultAlphaMatchesExplicit)
|
||||
{
|
||||
EXPECT_FLOAT_EQ(ck::swiglu_oai(2.f, 3.f), ck::swiglu_oai(2.f, 3.f, 7.0f, 1.702f));
|
||||
}
|
||||
Reference in New Issue
Block a user