diff --git a/kt-kernel/operators/amx/la/amx.hpp b/kt-kernel/operators/amx/la/amx.hpp index c8ce391..6281ca0 100644 --- a/kt-kernel/operators/amx/la/amx.hpp +++ b/kt-kernel/operators/amx/la/amx.hpp @@ -35,10 +35,10 @@ static inline __m512 exp_avx512(__m512 x) { const __m512 poly_6 = _mm512_set1_ps(0.0013333558f); __m512 frac_exp = _mm512_fmadd_ps( - frac_part, poly_6, - _mm512_fmadd_ps(frac_part, poly_5, - _mm512_fmadd_ps(frac_part, poly_4, - _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1))))); + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4), + frac_part, poly_3), + frac_part, poly_2), + frac_part, poly_1); __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part)); return _mm512_mul_ps(two_pow_i, frac_exp); diff --git a/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp b/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp index 81df642..cf3b2f9 100644 --- a/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp +++ b/kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp @@ -51,10 +51,10 @@ static inline __m512 exp_avx512(__m512 x) { const __m512 poly_6 = _mm512_set1_ps(0.0013333558f); __m512 frac_exp = _mm512_fmadd_ps( - frac_part, poly_6, - _mm512_fmadd_ps(frac_part, poly_5, - _mm512_fmadd_ps(frac_part, poly_4, - _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1))))); + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4), + frac_part, poly_3), + frac_part, poly_2), + frac_part, poly_1); __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part)); return _mm512_mul_ps(two_pow_i, frac_exp);