mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Very slightly better
This commit is contained in:
@@ -914,21 +914,17 @@ static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
|
||||
auto sc16 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)deq.x[i].scales)), shuff);
|
||||
scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
|
||||
scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
|
||||
//scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(deq.x[i].scales+0)))));
|
||||
//scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(deq.x[i].scales+8)))));
|
||||
//auto mins1 = _mm256_mul_ps(scales[0], _mm256_set1_ps(-32.f));
|
||||
//auto mins2 = _mm256_mul_ps(scales[1], _mm256_set1_ps(-32.f));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
|
||||
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
|
||||
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
|
||||
_mm256_storeu_ps(d8 + 8*iy, dy);
|
||||
//auto m4_1 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4))), 16));
|
||||
//auto m4_2 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4))), 16));
|
||||
//auto my1 = _mm256_set_m128(_mm_unpackhi_ps(m4_1, m4_1), _mm_unpacklo_ps(m4_1, m4_1)); // 0,0, 1,1, 2,2, 3,3
|
||||
//auto my2 = _mm256_set_m128(_mm_unpackhi_ps(m4_2, m4_2), _mm_unpacklo_ps(m4_2, m4_2)); // 4,4, 5,5, 6,6, 7,7
|
||||
//accd[iy] = _mm256_fmadd_ps(my1, mins1, accd[iy]);
|
||||
//accd[iy] = _mm256_fmadd_ps(my2, mins2, accd[iy]);
|
||||
if constexpr (nrc_y == 1) {
|
||||
auto dyh = _mm256_extractf128_ps(dy, 1);
|
||||
scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)));
|
||||
scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh));
|
||||
} else {
|
||||
_mm256_storeu_ps(d8 + 8*iy, dy);
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
@@ -966,9 +962,13 @@ static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
|
||||
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
|
||||
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
|
||||
#endif
|
||||
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
|
||||
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
|
||||
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
|
||||
if constexpr (nrc_y > 1) {
|
||||
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
|
||||
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
|
||||
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
|
||||
} else {
|
||||
accd[iy] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[iy]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user