// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp fenc=utf-8 :vi
//
//
// Copyright (C) 2024 Iwan Kawrakow
// MIT license
// SPDX-License-Identifier: MIT
//

#pragma once

#include "iqk_config.h"

#if defined IQK_IMPLEMENT

#include <cstring>
#include <type_traits>
#include <vector>
#include <cstdint>

#include "ggml-impl.h"
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
#include "iqk_quantize.h"

#define GGML_COMMON_IMPL_C
#include "ggml-common.h"

#define FA_TIMING 0

#include <utility>
#include <array>
#if FA_TIMING
#include <chrono>
#include <mutex>
struct Perf {
    using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
    std::array<double, 5> times = {};
    std::mutex mutex;
    bool report;
    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
    inline void accum(int what, const TimePoint& t1) {
        auto t2 = cur_time();
        auto dt = delta(t1, t2);
        std::lock_guard<std::mutex> lock(mutex);
        times[what] += dt;
    }
    inline void accum_nolock(int what, const TimePoint& t1) {
        auto t2 = cur_time();
        auto dt = delta(t1, t2);
        times[what] += dt;
    }
    inline void add(const Perf& other) {
        std::lock_guard<std::mutex> lock(mutex);
        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
    }
    Perf(bool r) : report(r) {}
    ~Perf() {
        if (report) {
            double tot = 0;
            for (auto& t : times) tot += t;
            if (!tot) return;
            printf("======================= Timing: %g ms in total\n", tot);
            for (int i = 0; i < int(times.size()); ++i) {
                if (times[i]) {
                    printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
                }
            }
        }
    }
    static Perf& instance() {
        static Perf p(true);
        return p;
    }
    static double delta(const TimePoint& t1, const TimePoint& t2) {
        return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
    }
};
#endif

#ifdef __AVX2__
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif

typedef struct {
    int32_t i1;
    int32_t i2;
} mmid_row_mapping;

struct DataInfo {
    float * s;
    const char * cy;
    size_t bs;
    size_t by;
    int cur_y = 0;
    int ne11;
    const mmid_row_mapping * row_mapping = nullptr;
    size_t bs2 = 0;

    inline const char * src1_row(int iy) const {
        if (!row_mapping) return cy + (cur_y + iy)*by;
        int i11 = row_mapping[cur_y + iy].i1 % ne11;
        int i12 = row_mapping[cur_y + iy].i2;
        return cy + (i11 + i12*ne11)*by;
    }

    inline void store(int ix, int iy, float result) const {
        *(dst_row(iy) + ix) = result;
    }
#ifdef __AVX__
    inline void store(int ix, int iy, __m128 result) const {
        _mm_storeu_ps(dst_row(iy) + ix, result);
    }
    inline void store(int ix, int iy, __m256 result) const {
        _mm256_storeu_ps(dst_row(iy) + ix, result);
    }
#endif
#ifdef __AVX512F__
    inline void store(int ix, int iy, __m512 result) const {
        _mm512_storeu_ps(dst_row(iy) + ix, result);
    }
#endif
#ifdef __ARM_NEON
    inline void store(int ix, int iy, float32x4_t result) const {
        vst1q_f32(dst_row(iy) + ix, result);
    }
#endif
    inline float * dst_row(int iy) const {
        if (!row_mapping) return s + (cur_y + iy)*bs;
        int i12 = row_mapping[cur_y + iy].i2;
        int i1 = row_mapping[cur_y + iy].i1;
        int i2 = i12;
        return s + i1*bs + i2*bs2;
    }
};
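
// DataInfo bundles the destination pointer and strides for the result matrix with an
// optional mmid_row_mapping. With row_mapping == nullptr, src1_row()/dst_row() address
// rows directly; otherwise the effective activation and output rows are resolved through
// the (i1, i2) mapping, as used by mul_mat_id-style (expert-routed) multiplications.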

typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);

#define IQK_MAX_NY 8

#define IQK_SET_MUL_MAT_FUNCTIONS_T(kernel, Dequantizer, funcs) \
    funcs[0] = kernel<Dequantizer, 1>;\
    funcs[1] = kernel<Dequantizer, 2>;\
    funcs[2] = kernel<Dequantizer, 3>;\
    funcs[3] = kernel<Dequantizer, 4>;\
    funcs[4] = kernel<Dequantizer, 5>;\
    funcs[5] = kernel<Dequantizer, 6>;\
    funcs[6] = kernel<Dequantizer, 7>;\
    funcs[7] = kernel<Dequantizer, 8>;\

#define IQK_SET_MUL_MAT_FUNCTIONS_T2(kernel, Dequantizer, Block, funcs) \
    funcs[0] = kernel<Dequantizer, 1, Block>;\
    funcs[1] = kernel<Dequantizer, 2, Block>;\
    funcs[2] = kernel<Dequantizer, 3, Block>;\
    funcs[3] = kernel<Dequantizer, 4, Block>;\
    funcs[4] = kernel<Dequantizer, 5, Block>;\
    funcs[5] = kernel<Dequantizer, 6, Block>;\
    funcs[6] = kernel<Dequantizer, 7, Block>;\
    funcs[7] = kernel<Dequantizer, 8, Block>;\

#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
    funcs[0] = kernel<1>;\
    funcs[1] = kernel<2>;\
    funcs[2] = kernel<3>;\
    funcs[3] = kernel<4>;\
    funcs[4] = kernel<5>;\
    funcs[5] = kernel<6>;\
    funcs[6] = kernel<7>;\
    funcs[7] = kernel<8>;\


// ==================================================================================================

static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
    const uint16_t * scales = (const uint16_t *)scales8;
    const uint32_t a0 = scales[0] | (scales[1] << 16);
    const uint32_t a1 = scales[2] | (scales[3] << 16);
    const uint32_t a2 = scales[4] | (scales[5] << 16);
    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
    aux32[2] = a1 & 0x3f3f3f3f;
    aux32[0] = a0 & 0x3f3f3f3f;
}
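
// make_q4_scales unpacks the 12-byte packed scales/mins block of the k-quants
// (q4_K/q5_K layout): aux32[0] and aux32[1] receive the eight 6-bit block scales,
// aux32[2] and aux32[3] the eight 6-bit block mins, one value per byte.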

const uint64_t keven_signs[128] = {
    0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
    0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
    0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
    0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
    0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
    0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
    0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
    0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
    0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
    0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
    0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
    0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
    0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
    0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
    0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
    0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
    0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
    0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
    0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
    0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
    0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
    0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
    0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
    0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
    0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
    0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
    0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
    0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
    0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
    0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
    0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
    0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
};
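
// Each entry packs eight sign bytes (0x01 for +1, 0xff for -1). The iq2/iq3 style quants
// store only 7 sign bits per group of 8 values; the last byte is chosen so that the number
// of negative signs is even, which is why 128 entries cover all sign patterns in use.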

#ifdef __AVX2__

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

static inline float hsum_float_4(__m128 x) {
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
    return _mm_cvtss_f32(x);
}
static inline float hsum_float_8(__m256 x) {
    return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
}
static inline int hsum_i32_8(const __m256i a) {
    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
static inline float hmax_f32_8(__m256 x) {
    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
    return _mm_cvtss_f32(max4);
}
static inline float hmax_float_8(__m256 x) {
    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
    max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
    return _mm_cvtss_f32(max4);
}

static inline __m128 hsum_float_4x4(__m128 * accm) {
    accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2]));
    accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3]));
    return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1]));
}
static inline __m256 hsum_float_8x8(__m256 * accm) {
    for (int i = 0; i < 4; ++i) {
        accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i + 4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i + 4], 0x31));
        //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
        //                          _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
    }
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i + 2]), _mm256_unpackhi_ps(accm[i], accm[i + 2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
static inline __m256 hsum_float_4x8(__m256 * accm) {
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i + 2]), _mm256_unpackhi_ps(accm[i], accm[i + 2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}

static inline __m128i load_iq4nl_values_128() {
    static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
    return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
}
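
// kvalues_iq4nl above is the 16-entry IQ4_NL codebook stored with a +128 offset so all
// entries are non-negative bytes (the canonical signed values run from -127 to 113);
// the offset is presumably compensated elsewhere via the q8 block sums.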

static inline __m256i load_iq4nl_values_256() {
    auto val128 = load_iq4nl_values_128();
    return MM256_SET_M128I(val128, val128);
}

#ifdef HAVE_FANCY_SIMD
static inline __m512i load_iq4nl_values_512() {
    auto val256 = load_iq4nl_values_256();
    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
}
#endif

static inline __m128i load_iq4k_values_128() {
    return _mm_loadu_si128((const __m128i *)iq4k_values);
}

static inline __m256i load_iq4k_values_256() {
    auto val128 = load_iq4k_values_128();
    return MM256_SET_M128I(val128, val128);
}

template <int nrc, typename block_q8 = block_q8_K> struct Q8 {

    constexpr static int nrc_y = nrc;

    Q8(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
    }

#ifdef HAVE_FANCY_SIMD
    inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
#endif
    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
    inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
    inline float scale(int iy, int i) const { return y[iy][i].d; }

    const block_q8 * y[nrc_y];
};

template <int nrc> struct Q8_16 {

    constexpr static int nrc_y = nrc;

    Q8_16(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto ptr = (const float *)info.src1_row(iy);
            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
            y[iy] = (const int8_t *)(ptr + 5);
        }
    }

#ifdef HAVE_FANCY_SIMD
    inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
#endif
    inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
    inline float scale(int iy, int k) const { return d[5*iy+k]; }
    inline float sum_row(int iy) const { return d[5*iy + 4]; }
    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }

    float d[5*nrc_y];
    const int8_t * y[nrc_y];
};

struct Scales8KBase {
    template <typename Q8>
    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i q8s = q8.load_bsums(iy, i);
            const __m256i prod = _mm256_madd_epi16(mins, q8s);
            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
        }
    }
    inline __m256i shuffle(__m128i mins) const {
        return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
    }
    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
};

template <typename Block, bool per_row_scale = false, bool is_f16 = false>
struct BaseDequantizer {
    BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
    inline void new_row(int ix) {
        if constexpr (per_row_scale) {
            if constexpr (is_f16) {
                const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
                d = GGML_FP16_TO_FP32(*dptr);
                x = (const Block *)(dptr + 1);
            } else {
                const float * dptr = (const float *)((const char *)vx + bx*ix);
                d = *dptr;
                x = (const Block *)(dptr + 1);
            }
        } else {
            x = (const Block *)((const char *)vx + bx*ix);
        }
    }

    const void * vx;
    const size_t bx;
    const Block * x;

    float d;
};

template <typename Q8, typename Bits>
static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    if (j == 0) {
#ifdef HAVE_FANCY_SIMD
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
        }
#else
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
        }
#endif
    } else {
#ifdef HAVE_FANCY_SIMD
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
        }
#else
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
        }
#endif
    }
}
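
// multiply_add accumulates one half of a 256-quant super-block (j == 0 initializes sumi,
// j == 1 adds the second half). With HAVE_FANCY_SIMD the 16-bit dot products are folded
// into the 32-bit accumulators via _mm256_dpwssd_epi32; the plain AVX2 path instead uses
// _mm256_maddubs_epi16 followed by _mm256_madd_epi16 against the block scales.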

template <typename Q8, typename Bits>
static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    __m256i p[4];
    if (j == 0) {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            for (int k = 0; k < 4; ++k) {
                auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
                p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k])));
            }
            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
        }
    } else {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            for (int k = 0; k < 4; ++k) {
                auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
                p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k])));
            }
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2]));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3]));
        }
    }
}
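
// multiply_add_avx2 handles x quants that may be signed: _mm256_maddubs_epi16 needs an
// unsigned first operand, so each x vector is replaced by its absolute value via
// _mm256_sign_epi8(x, x) while the matching q8 bytes get x's sign transferred onto them,
// leaving the products |x| * sign(x)*y == x*y unchanged.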

#ifdef HAVE_FANCY_SIMD

struct BlockPermuter {
    const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
};

struct Q4Bits {
    inline void prepare(const uint8_t * q4) {
        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
        auto tmp1 = _mm512_and_si512(q4bits, ml);
        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
        tmp1 = _mm512_and_si512(q4bits, ml);
        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
    }
    inline void prepare64(const uint8_t * q4) {
        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
        values[0] = _mm512_and_si512(q4bits, ml);
        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
        values[2] = _mm512_and_si512(q4bits, ml);
        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
    }
    inline void prepare64a(const uint8_t * q4) {
        for (int k = 0; k < 4; ++k) {
            auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
            values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
            values[k] = _mm512_and_si512(values[k], ml);
        }
    }
    __m512i values[4];
    const __m512i ml = _mm512_set1_epi8(0xf);
    const BlockPermuter perm;
};

struct Q2Bits {
    inline void prepare(const uint8_t * q2) {

        auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
        auto tmp = _mm512_srli_epi16(q2bits, 2);

        values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
        values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
        values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
        values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
        values[0] = _mm512_and_si512(values[0], ml);
        values[2] = _mm512_and_si512(values[2], ml);
    }
    __m512i values[4];
    const __m512i ml = _mm512_set1_epi8(0x03);
    BlockPermuter perm;
};

#else

struct Q2Bits {
    inline void prepare(const uint8_t * q2, int j) {
        auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
        values[0] = _mm256_and_si256(q2bits, ml);
        values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
        values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0x03);
};

struct Q4Bits {
    inline void prepare(const uint8_t * q4, int j) {
        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
        values[0] = _mm256_and_si256(q4bits, ml);
        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
        values[2] = _mm256_and_si256(q4bits, ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
    }
    inline void prepare64(const uint8_t * q4, int j) {
        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
        values[0] = _mm256_and_si256(q4bits, ml);
        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
        values[1] = _mm256_and_si256(q4bits, ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
    }
    inline void prepare16(const uint8_t * q4, int j) {
        values[0] = dequant16(q4 + 64*j + 0);
        values[1] = dequant16(q4 + 64*j + 16);
        values[2] = dequant16(q4 + 64*j + 32);
        values[3] = dequant16(q4 + 64*j + 48);
    }
    inline __m256i dequant16(const uint8_t * qs) const {
        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
        return _mm256_and_si256(ml, aux256);
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0xf);
};

#endif

inline void iqk_transpose_8x8(__m256 * m) {
    for (int k = 0; k < 8; k += 4) {
        auto t0 = _mm256_unpacklo_ps(m[k+0], m[k+1]);
        auto t1 = _mm256_unpacklo_ps(m[k+2], m[k+3]);
        auto t2 = _mm256_unpackhi_ps(m[k+0], m[k+1]);
        auto t3 = _mm256_unpackhi_ps(m[k+2], m[k+3]);
        m[k+0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
        m[k+1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
        m[k+2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
        m[k+3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
    }
    for (int k = 0; k < 4; ++k) {
        auto t = _mm256_set_m128(_mm256_extractf128_ps(m[k+4], 1), _mm256_extractf128_ps(m[k], 1));
        m[k+0] = _mm256_set_m128(_mm256_castps256_ps128(m[k+4]), _mm256_castps256_ps128(m[k+0]));
        m[k+4] = t;
    }
}
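
// iqk_transpose_8x8 transposes an 8x8 float tile held in eight __m256 registers in place:
// the first pass interleaves elements within the 128-bit lanes (two 4x4 sub-tiles per
// register group), the second pass swaps the 128-bit halves between rows k and k+4.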

template <int nr = 8>
static inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
    auto max_i16 = _mm256_setzero_si256();
    __m256i qs[16];
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
        qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
        qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
        qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
    }
    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
    auto max4 = _mm_cvtepi32_ps(imax4);
    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
    bool needs_scaling = true;
    float dnew = _mm_cvtss_f32(max4) * d0;
    if (dnew < 1.f) {
        dnew = 1.f; needs_scaling = false;
    }
    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        if (needs_scaling) {
            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
            i0 = _mm256_packs_epi32(i0, i1);
            i2 = _mm256_packs_epi32(i2, i3);
            i0 = _mm256_packs_epi16(i0, i2);
            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
            _mm256_storeu_si256((__m256i *)block, i0);
        } else {
            // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
            auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
            auto i0_l = _mm256_castsi256_si128(i0);
            auto i0_h = _mm256_extracti128_si256(i0, 1);
            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
        }
        auto qs = (uint32_t *)q8_k + 8*nr*ib32;
        for (int l = 0; l < 8; ++l) {
            qs[nr*l + k] = block[l];
        }
    }
    return dnew;
}
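
// convert_to_q8_k_r8 widens the 256 quants in qx to 16 bits, multiplies them by their
// per-block scales, derives a row scale dnew from the maximum absolute value times d0,
// rescales back to int8 when dnew exceeds 1, and scatters the bytes into the
// row-interleaved q8_k_r8 layout (qs[nr*l + k]) for row k of an nr-row group,
// returning dnew.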

#else
// ------------------------------------ __aarch64__ --------------------------------------------------

template <int nrc, typename block_q8 = block_q8_K> struct Q8 {

    constexpr static int nrc_y = nrc;

    Q8(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
    }

    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
    inline int16x8_t load_bsums8(int iy, int i) const {
        auto q8s = vld1q_s16_x2(y[iy][i].bsums);
        return vpaddq_s16(q8s.val[0], q8s.val[1]);
    }
    inline float scale(int iy, int i) const { return y[iy][i].d; }

    const block_q8 * y[nrc_y];
};

template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
struct BaseDequantizer {
    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
    inline void new_row(int ix) {
        if constexpr (has_row_scale) {
            if constexpr (scale_is_f16) {
                const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
                d = GGML_FP16_TO_FP32(*dptr);
                x = (const block_q *)(dptr + 1);
            } else {
                const float * dptr = (const float *)((const char *)vx + ix*bx);
                d = *dptr;
                x = (const block_q *)(dptr + 1);
            }
        } else {
            x = (const block_q *)((const char *)vx + ix*bx);
        }
    }
    const void * vx;
    const block_q * x;
    const size_t bx;
    const int nrc;
    float d;
};

struct Q4bits {
    const uint8x16_t m4b = vdupq_n_u8(0xf);
    uint8x16x4_t b1, b2;
    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
        b.val[0] = vandq_u8(val[0], m4b);
        b.val[2] = vshrq_n_u8(val[0], 4);
        b.val[1] = vandq_u8(val[1], m4b);
        b.val[3] = vshrq_n_u8(val[1], 4);
    }
    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
        b.val[0] = vandq_u8(val[0], m4b);
        b.val[1] = vshrq_n_u8(val[0], 4);
        b.val[2] = vandq_u8(val[1], m4b);
        b.val[3] = vshrq_n_u8(val[1], 4);
    }
    inline void prepare(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x2(qs);
        prepare4(b1, q4bits.val);
        q4bits = vld1q_u8_x2(qs+32);
        prepare4(b2, q4bits.val);
    }
    inline void prepare_v2(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        prepare4(b1, q4bits.val+0);
        prepare4(b2, q4bits.val+2);
    }
    inline void prepare64(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        b1.val[0] = vandq_u8(q4bits.val[0], m4b);
        b1.val[1] = vandq_u8(q4bits.val[1], m4b);
        b1.val[2] = vandq_u8(q4bits.val[2], m4b);
        b1.val[3] = vandq_u8(q4bits.val[3], m4b);
        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
    }
    inline void prepare16(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x2(qs);
        prepare4_16(b1, q4bits.val);
        q4bits = vld1q_u8_x2(qs+32);
        prepare4_16(b2, q4bits.val);
    }
    inline void prepare16_v2(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        prepare4_16(b1, q4bits.val+0);
        prepare4_16(b2, q4bits.val+2);
    }
};

struct Q2bits {
    const uint8x16_t m4b = vdupq_n_u8(0x03);
    uint8x16x4_t b1, b2;
    inline void prepare(const uint8_t * qs) {
        auto q2bits = vld1q_u8_x2(qs);
        b1.val[0] = vandq_u8(q2bits.val[0], m4b);
        b1.val[1] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b1.val[2] = vandq_u8(q2bits.val[0], m4b);
        b1.val[3] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b2.val[0] = vandq_u8(q2bits.val[0], m4b);
        b2.val[1] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b2.val[2] = vandq_u8(q2bits.val[0], m4b);
        b2.val[3] = vandq_u8(q2bits.val[1], m4b);
    }
};

template <typename Q8>
static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
    auto mzero = vdupq_n_s32(0);
    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
            vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
            vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
    auto p12 = vpaddq_s32(p1, p2);

    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
            vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
            vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
    auto p34 = vpaddq_s32(p3, p4);

    auto pall = vpaddq_s32(p12, p34);
    sumi = vmlaq_s32(sumi, scales.val[j], pall);
}

template <typename Q8>
static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {

    auto mzero = vdupq_n_s32(0);
    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4,
    auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);

    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
}

struct SignHelper {

    inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }

    inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
        auto aux = vqtbl1q_u8(signs16, shuffle);
        auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
        shuffle = vaddq_u8(shuffle, step);
    }

    const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
    const uint8x16_t m1 = vdupq_n_u8(1);
    const uint8x16_t step = vdupq_n_u8(2);
    uint8x16_t shuffle;
};
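
// SignHelper expands packed sign bits into per-byte +/-1 factors: two sign bytes are
// broadcast via vqtbl1q_u8, each bit is tested against smask (0x01..0x80), and the
// compare result (0xff or 0x00) is OR-ed with 1 to give -1 or +1 before being
// multiplied into the quants; shuffle advances by two bytes on every call.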

template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K;

    Q8<nrc_y, block_q8_K> q8(info);

    Dequantizer deq(vx, bx, nrc_y);

    for (int ix = 0; ix < nrc_x; ++ix) {

        deq.new_row(ix);

        float32x4_t acc[nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);

        for (int i = 0; i < nb; ++i) {

            int32x4_t sumi[nrc_y];
            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);

            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
                deq.process_scales(i, q8, acc);
                deq.prepare(i, 0);
                deq.compute(q8, i, 0, sumi);
                deq.prepare(i, 1);
                deq.compute(q8, i, 1, sumi);
            } else {
                if constexpr (Dequantizer::num_blocks() == 8) {
                    auto scales = deq.new_block(i, q8, acc);
                    deq.prepare(i, 0);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
                    deq.prepare(i, 1);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
                }
                else if constexpr (Dequantizer::num_blocks() == 16) {
                    auto scales = deq.new_block(i, q8, acc);
                    deq.prepare(i, 0);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
                    deq.prepare(i, 1);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
                }
                else {
                    GGML_ASSERT(false);
                }
            }

            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
            }
        }

        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(acc[iy]));
        }
    }
}
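
// mul_mat_qX_K_q8_K_T is the generic NEON kernel for k-quant rows against q8_K
// activations: for each of the nrc_x rows it walks the QK_K super-blocks, lets the
// Dequantizer unpack quants and block scales, accumulates integer dot products with
// compute_8_blocks / compute_16_blocks (depending on Dequantizer::num_blocks()), and
// folds in the float super-block scales before storing the horizontal sums via info.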

static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
    auto sumi = vdupq_n_s32(0);
    sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
    sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
    sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
    sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
    sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
    sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
    sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
    sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
    return sumi;
}

static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
    int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
    return sumi;
}

static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
    auto sumi = vdupq_n_s32(0);
    sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
    sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
    sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
    sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
    return sumi;
}

static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
    qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));  // 0...3 from the 4 rows
    qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));  // 16..19
    qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));  // 4...7
    qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));  // 20..23
    qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
    qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
    qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
    qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
}

static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
    qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
    qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
    qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
    qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
}

#endif

#endif