From b27de4068ba9eb92573da48fbd1347d5d99dda96 Mon Sep 17 00:00:00 2001 From: mrhaoxx Date: Tue, 20 Jan 2026 11:07:22 +0800 Subject: [PATCH] [fix]: fix exp_avx512 for act_fn (#1797) --- kt-kernel/operators/amx/la/amx.hpp | 8 ++++---- kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/kt-kernel/operators/amx/la/amx.hpp b/kt-kernel/operators/amx/la/amx.hpp index c8ce391..6281ca0 100644 --- a/kt-kernel/operators/amx/la/amx.hpp +++ b/kt-kernel/operators/amx/la/amx.hpp @@ -35,10 +35,10 @@ static inline __m512 exp_avx512(__m512 x) { const __m512 poly_6 = _mm512_set1_ps(0.0013333558f); __m512 frac_exp = _mm512_fmadd_ps( - frac_part, poly_6, - _mm512_fmadd_ps(frac_part, poly_5, - _mm512_fmadd_ps(frac_part, poly_4, - _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1))))); + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4), + frac_part, poly_3), + frac_part, poly_2), + frac_part, poly_1); __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part)); return _mm512_mul_ps(two_pow_i, frac_exp); diff --git a/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp b/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp index 81df642..cf3b2f9 100644 --- a/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp +++ b/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp @@ -51,10 +51,10 @@ static inline __m512 exp_avx512(__m512 x) { const __m512 poly_6 = _mm512_set1_ps(0.0013333558f); __m512 frac_exp = _mm512_fmadd_ps( - frac_part, poly_6, - _mm512_fmadd_ps(frac_part, poly_5, - _mm512_fmadd_ps(frac_part, poly_4, - _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1))))); + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4), + frac_part, poly_3), + frac_part, poly_2), + frac_part, poly_1); __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part)); return _mm512_mul_ps(two_pow_i, frac_exp);