diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp index 442fbbf846..7b6fe26dc6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -30,9 +30,28 @@ enum Activation { gelu_and_mul = 0, silu_and_mul = 1, - swiglustep_and_mul = 2 + swiglustep_and_mul = 2, + swiglu_oai_and_mul = 3 }; +// OAI / gpt-oss SwiGLU activation: gate * sigmoid(alpha * gate) * (up + 1), with a +// pre-activation clamp (gate upper-bounded to limit, up symmetric in [-limit, limit]). +// Defaults limit = 7.0, alpha = 1.702 per gpt-oss. Same math as ck_tile::moe::Swiglu, but +// the sigmoid here uses a plain fp32 division (not __builtin_amdgcn_rcpf), so results match +// to ~1e-6 rather than bit-exact. Single source of truth shared by the MoE kernel epilogue +// (swiglu_oai_and_mul) and its host unit test. +// Not constexpr: it always calls math::exp (non-constexpr), so it can never be evaluated in a +// constant expression. inline keeps it ODR-safe in this header. The kernel epilogue and the host +// unit test only call it at runtime. +__host__ __device__ inline float +swiglu_oai(float gate, float up, float limit = 7.0f, float alpha = 1.702f) +{ + gate = math::min(gate, limit); // gate <= limit + up = math::min(math::max(up, -limit), limit); // up in [-limit, limit] + const float sig = 1.0f / (1.0f + math::exp(alpha * -gate)); // sigmoid(alpha * gate) + return gate * sig * (up + 1.0f); // OAI form +} + template ; @@ -1506,6 +1525,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { @@ -1579,6 +1618,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { @@ -2011,6 +2062,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { @@ -2084,6 +2155,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< } c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); } + else if constexpr(ActivationOperation == + Activation::swiglu_oai_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglu_oai_activation(gate, up); + } } else { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 5b791d3668..4e8b954afd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -1140,6 +1140,9 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + static_assert(ActivationOperation != Activation::swiglu_oai_and_mul, + "gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the " + "non-blockscale gridwise_moe_gemm."); #if defined(__gfx942__) || defined(__gfx950__) constexpr auto b_coherence_flag = NonTemporalLoadB ? AmdBufferCoherenceEnum::WAVE_NT1 @@ -1694,6 +1697,9 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + static_assert(ActivationOperation != Activation::swiglu_oai_and_mul, + "gridwise_moe_gemm_blockscale does not support swiglu_oai_and_mul; use the " + "non-blockscale gridwise_moe_gemm."); #if defined(__gfx942__) || defined(__gfx950__) constexpr auto b_coherence_flag = NonTemporalLoadB ? AmdBufferCoherenceEnum::WAVE_NT1 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 80e6e4709b..93b1c1812a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -350,3 +350,4 @@ add_subdirectory(synchronization) add_subdirectory(gpu_reference) add_subdirectory(util) add_subdirectory(gpu_verification) +add_subdirectory(swiglu_oai_activation) diff --git a/test/swiglu_oai_activation/CMakeLists.txt b/test/swiglu_oai_activation/CMakeLists.txt new file mode 100644 index 0000000000..4b3b41c7de --- /dev/null +++ b/test/swiglu_oai_activation/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_gtest_executable(test_swiglu_oai_activation test_swiglu_oai_activation.cpp) +if(result EQUAL 0) + target_link_libraries(test_swiglu_oai_activation PRIVATE utility) +endif() diff --git a/test/swiglu_oai_activation/test_swiglu_oai_activation.cpp b/test/swiglu_oai_activation/test_swiglu_oai_activation.cpp new file mode 100644 index 0000000000..371787a68d --- /dev/null +++ b/test/swiglu_oai_activation/test_swiglu_oai_activation.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +// ck::swiglu_oai is the single source of truth shared by the XDL 2-stage MoE epilogue +// (Activation::swiglu_oai_and_mul) and this host unit test. +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" + +namespace { + +// Independent fp64 reference for the OAI / gpt-oss SwiGLU activation: +// gate * sigmoid(alpha * gate) * (up + 1) +// with gate clamped to <= limit and up clamped to [-limit, limit]. Computed in double so +// we are not comparing the fp32 implementation under test against a copy of itself. +double ref_swiglu_oai(double gate, double up, double limit = 7.0, double alpha = 1.702) +{ + gate = std::min(gate, limit); + up = std::min(std::max(up, -limit), limit); + const double s = 1.0 / (1.0 + std::exp(alpha * -gate)); + return gate * s * (up + 1.0); +} + +constexpr float kTol = 1e-3f; + +} // namespace + +// Values match the reference when no clamping is active. +TEST(SwigluOai, MatchesReferenceInRange) +{ + const float pts[][2] = {{0.f, 0.f}, {1.f, 2.f}, {-1.5f, 0.5f}, {3.f, -2.f}, {-4.f, -3.f}}; + for(const auto& p : pts) + { + const float got = ck::swiglu_oai(p[0], p[1]); + const float ref = static_cast(ref_swiglu_oai(p[0], p[1])); + EXPECT_NEAR(got, ref, kTol) << "gate=" << p[0] << " up=" << p[1]; + } +} + +// gate is upper-bounded to limit (7); it is NOT lower-clamped. +TEST(SwigluOai, GateUpperClamp) +{ + EXPECT_NEAR(ck::swiglu_oai(100.f, 1.f), static_cast(ref_swiglu_oai(7.0, 1.0)), kTol); + // gate >= limit all saturate to the same value. + EXPECT_NEAR(ck::swiglu_oai(7.f, 1.f), ck::swiglu_oai(100.f, 1.f), kTol); + // large negative gate passes through unclamped. + EXPECT_NEAR(ck::swiglu_oai(-50.f, 1.f), static_cast(ref_swiglu_oai(-50.0, 1.0)), kTol); +} + +// up is symmetric-clamped to [-7, 7]. +TEST(SwigluOai, UpSymmetricClamp) +{ + EXPECT_NEAR(ck::swiglu_oai(1.f, 100.f), static_cast(ref_swiglu_oai(1.0, 7.0)), kTol); + EXPECT_NEAR(ck::swiglu_oai(1.f, -100.f), static_cast(ref_swiglu_oai(1.0, -7.0)), kTol); +} + +// The "+1" shift on up is part of the OAI form: up == -1 zeroes the output exactly. +TEST(SwigluOai, UpPlusOneShiftZeroesOutput) +{ + EXPECT_FLOAT_EQ(ck::swiglu_oai(2.f, -1.f), 0.f); + EXPECT_FLOAT_EQ(ck::swiglu_oai(-3.f, -1.f), 0.f); +} + +// alpha defaults to 1.702 (gpt-oss); passing it explicitly must not change the result. +TEST(SwigluOai, DefaultAlphaMatchesExplicit) +{ + EXPECT_FLOAT_EQ(ck::swiglu_oai(2.f, 3.f), ck::swiglu_oai(2.f, 3.f, 7.0f, 1.702f)); +}