mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-04 13:30:47 +00:00
Hopefully this really fixes the confusion between AVX512 and FANCY_SIMD (#216)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -254,6 +254,7 @@ if (GGML_BLAS)
|
||||
endif()
|
||||
|
||||
set (GGML_SOURCES_IQK iqk/iqk_quantize.cpp)
|
||||
set (GGML_HEADERS_IQK iqk/iqk_config.h)
|
||||
if (GGML_IQK_MUL_MAT)
|
||||
message(STATUS "Using optimized iqk matrix multiplications")
|
||||
add_compile_definitions(GGML_USE_IQK_MULMAT)
|
||||
@@ -1324,7 +1325,7 @@ add_library(ggml
|
||||
${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
|
||||
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
|
||||
${GGML_SOURCES_IQK_MM} ${GGML_HEADERS_IQK_MM}
|
||||
${GGML_SOURCES_IQK}
|
||||
${GGML_SOURCES_IQK} ${GGML_HEADERS_IQK}
|
||||
${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
|
||||
ggml-aarch64.c ggml-aarch64.h
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "iqk/iqk_quantize.h"
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
#include "iqk/iqk_mul_mat.h"
|
||||
#include "iqk/iqk_config.h"
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
@@ -847,11 +848,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float_to_mat = quantize_mat_q8_0,
|
||||
.vec_dot = ggml_vec_dot_q8_0_q8_0,
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
// Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2
|
||||
// because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range
|
||||
// (and it gets satured if it does), leading to wrong results.
|
||||
// TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above.
|
||||
.vec_dot_type = GGML_TYPE_Q8_1_X4,
|
||||
#else
|
||||
.vec_dot_type = GGML_TYPE_Q8_0_X4,
|
||||
|
||||
30
ggml/src/iqk/iqk_config.h
Normal file
30
ggml/src/iqk/iqk_config.h
Normal file
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#if defined IQK_IMPLEMENT
|
||||
#undef IQK_IMPLEMENT
|
||||
#endif
|
||||
|
||||
#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD
|
||||
#define IQK_IMPLEMENT
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define IQK_NOINLINE __declspec(noinline)
|
||||
#define IQK_ALWAYS_INLINE inline
|
||||
#if !defined __x86_64__ && defined _M_X64
|
||||
#define __x86_64__
|
||||
#endif
|
||||
#else
|
||||
#define IQK_NOINLINE __attribute__((__noinline__))
|
||||
#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
|
||||
#endif
|
||||
|
||||
#if defined __x86_64__
|
||||
#if defined HAVE_FANCY_SIMD
|
||||
#undef HAVE_FANCY_SIMD
|
||||
#endif
|
||||
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#define HAVE_FANCY_SIMD
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -7,20 +7,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
#if defined IQK_IMPLEMENT
|
||||
#undef IQK_IMPLEMENT
|
||||
#endif
|
||||
#include "iqk_config.h"
|
||||
|
||||
#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD
|
||||
#define IQK_IMPLEMENT
|
||||
#endif
|
||||
#if defined IQK_IMPLEMENT
|
||||
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#if defined IQK_IMPLEMENT
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
#include "iqk_mul_mat.h"
|
||||
@@ -100,26 +94,6 @@ struct Perf {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define IQK_NOINLINE __declspec(noinline)
|
||||
#define IQK_ALWAYS_INLINE inline
|
||||
#if !defined __x86_64__ && defined _M_X64
|
||||
#define __x86_64__
|
||||
#endif
|
||||
#else
|
||||
#define IQK_NOINLINE __attribute__((__noinline__))
|
||||
#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
|
||||
#endif
|
||||
|
||||
#if defined __x86_64__
|
||||
#if defined HAVE_FANCY_SIMD
|
||||
#undef HAVE_FANCY_SIMD
|
||||
#endif
|
||||
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#define HAVE_FANCY_SIMD
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
typedef struct {
|
||||
@@ -1472,7 +1446,7 @@ inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
|
||||
template <typename Q8, typename Bits>
|
||||
inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
|
||||
if (j == 0) {
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
|
||||
@@ -1489,7 +1463,7 @@ inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i,
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
|
||||
sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
|
||||
@@ -2747,7 +2721,7 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
|
||||
auto h1 = _mm256_andnot_si256(mask4, hbits);
|
||||
auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1);
|
||||
auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2);
|
||||
auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff));
|
||||
auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff;
|
||||
return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)),
|
||||
_mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))),
|
||||
_mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)),
|
||||
@@ -2843,7 +2817,7 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
|
||||
const __m256i values;
|
||||
__m256i data[4];
|
||||
const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001);
|
||||
const __m256i bmask = _mm256_set1_epi16(0xfffe);
|
||||
const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe;
|
||||
const __m128i mask = _mm_set1_epi16(254);
|
||||
const __m128i m127 = _mm_set1_epi16(-127);
|
||||
const __m128i m128 = _mm_set1_epi16(-128);
|
||||
@@ -7049,7 +7023,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
template <typename Bits>
|
||||
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
|
||||
if (j == 0) {
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
|
||||
auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
|
||||
auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
|
||||
@@ -7065,7 +7039,7 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
|
||||
sumi[1] = _mm256_add_epi32(p2, p4);
|
||||
#endif
|
||||
} else {
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
|
||||
auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
|
||||
auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
|
||||
@@ -7282,7 +7256,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
__m256i accd[nrc_y];
|
||||
__m256i val[4];
|
||||
|
||||
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
const auto m1_16 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
|
||||
@@ -7304,7 +7278,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
|
||||
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
|
||||
acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
|
||||
#else
|
||||
@@ -7328,7 +7302,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i, 0)),
|
||||
val[1], q8.load_quants(iy, i, 1)),
|
||||
@@ -7349,7 +7323,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
if (i < nb) {
|
||||
deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
|
||||
#else
|
||||
@@ -7401,7 +7375,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
__m256i accd[nrc_y];
|
||||
__m256i val[4];
|
||||
|
||||
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
const auto m1_16 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
|
||||
@@ -7413,7 +7387,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
__m256i acc[2] = {};
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare4(i, val);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
|
||||
val[1], q8.load_quants(0, i, 1));
|
||||
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
|
||||
@@ -7436,7 +7410,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare4(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
|
||||
val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
|
||||
@@ -7455,7 +7429,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
if (i < nb) {
|
||||
deq.prepare2(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
|
||||
val[1], q8.load_quants(iy, i/2, 1));
|
||||
#else
|
||||
@@ -8537,7 +8511,7 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
|
||||
xv[1] = load1(ix+1, i);
|
||||
xv[2] = load1(ix+2, i);
|
||||
xv[3] = load1(ix+3, i);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
|
||||
@@ -14749,7 +14723,7 @@ struct BaseHelper {
|
||||
};
|
||||
|
||||
struct F16 {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
using Data = __m512;
|
||||
constexpr static int block_size = 16;
|
||||
constexpr static int num_registers = 32;
|
||||
@@ -14910,7 +14884,7 @@ struct HelperQ8KV final : public BaseHelper<step> {
|
||||
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
|
||||
#else
|
||||
auto vd = F16::set1(q8->d);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
|
||||
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
|
||||
#else
|
||||
@@ -14945,7 +14919,7 @@ struct HelperQ80 final : public BaseHelper<step> {
|
||||
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
|
||||
#else
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
|
||||
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
|
||||
#else
|
||||
@@ -15215,7 +15189,7 @@ struct HelperQ40 final : public BaseHelper<step> {
|
||||
#else
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
|
||||
auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
@@ -15260,7 +15234,7 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
|
||||
auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
auto ql = _mm_and_si128(q, mask);
|
||||
auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
|
||||
v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
|
||||
@@ -15306,7 +15280,7 @@ struct HelperIQ4nl final : public BaseHelper<step> {
|
||||
#else
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
|
||||
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask));
|
||||
auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask));
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
@@ -15361,7 +15335,7 @@ struct HelperQ60 final : public BaseHelper<step> {
|
||||
auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
|
||||
auto bh = _mm_set_epi64x(aux64, aux64 << 4);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX512F__
|
||||
auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
|
||||
auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
@@ -15537,6 +15511,22 @@ struct FlashMS {
|
||||
}
|
||||
return F16::reduce_max<k_step>(vk);
|
||||
}
|
||||
static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) {
|
||||
auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
|
||||
m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
|
||||
auto m256 = _mm256_cvtepi16_epi32(m128);
|
||||
auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
|
||||
return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
|
||||
}
|
||||
#ifdef __AVX512F__
|
||||
static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
|
||||
auto m256 = _mm256_loadu_si256((const __m256i *)mask+l);
|
||||
m256 = _mm256_cmpeq_epi16(m256, _mm256_setzero_si256());
|
||||
auto m512 = _mm512_cvtepi16_epi32(m256);
|
||||
auto mf = _mm512_castsi512_ps(_mm512_or_si512(m512, _mm512_slli_epi32(m512, 16)));
|
||||
return _mm512_or_ps(_mm512_and_ps(mf, val), _mm512_andnot_ps(mf, vinf));
|
||||
}
|
||||
#endif
|
||||
inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto vzero = _mm256_set1_epi16(0);
|
||||
@@ -15554,15 +15544,9 @@ struct FlashMS {
|
||||
}
|
||||
}
|
||||
#else
|
||||
auto vzero = _mm_set1_epi16(0);
|
||||
auto vinf = F16::set1(-INFINITY);
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) {
|
||||
auto m128 = _mm_loadu_si128((const __m128i *)mask + l);
|
||||
m128 = _mm_cmpeq_epi16(m128, vzero);
|
||||
auto m256 = _mm256_cvtepi16_epi32(m128);
|
||||
auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
|
||||
auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l);
|
||||
vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
|
||||
vk[l] = apply_mask(l, mask, F16::load(cache + k_step*j + F16::block_size*l), vinf);
|
||||
}
|
||||
if (softcap <= 0) {
|
||||
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
|
||||
@@ -15630,14 +15614,12 @@ struct FlashQKV {
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
|
||||
}
|
||||
}
|
||||
//F16::Data v[8];
|
||||
F16::Data v0, v1;
|
||||
for (int l = 0; l < k_step; l += 4) {
|
||||
auto vs0 = F16::set1(fms.cache[l + 0]);
|
||||
auto vs1 = F16::set1(fms.cache[l + 1]);
|
||||
auto vs2 = F16::set1(fms.cache[l + 2]);
|
||||
auto vs3 = F16::set1(fms.cache[l + 3]);
|
||||
//auto vs = F16::set4(fms.cache + l);
|
||||
for (int i = 0; i < D/F16::block_size; i += 2) {
|
||||
vh.load(l+0, i, v0, v1);
|
||||
vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
|
||||
@@ -15651,14 +15633,6 @@ struct FlashQKV {
|
||||
vh.load(l+3, i, v0, v1);
|
||||
vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
|
||||
vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
|
||||
//vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
|
||||
//vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
|
||||
//vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
|
||||
//vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
|
||||
//vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
|
||||
//vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
|
||||
//vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
|
||||
//vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-common.h"
|
||||
#include "iqk_quantize.h"
|
||||
#include "iqk_config.h"
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
@@ -43,15 +44,6 @@ constexpr int popcount(uint32_t x) { return __builtin_popcount(x); }
|
||||
constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); }
|
||||
#endif
|
||||
|
||||
#if defined __x86_64__
|
||||
#if defined HAVE_FANCY_SIMD
|
||||
#undef HAVE_FANCY_SIMD
|
||||
#endif
|
||||
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
#define HAVE_FANCY_SIMD
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
inline int nearest_int(float fval) {
|
||||
|
||||
Reference in New Issue
Block a user