mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-02-24 07:04:11 +00:00
Refactor iqk: FA compiles
Whether it works is a different story. Current compile time: 107.3 seconds on the Ryzen-7950X.
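For orientation before the diff: the refactor routes the flash-attention (FA) matrix multiplications through small dispatchers that turn the runtime k_step and the number of query rows nq into template instantiations, returning the chosen kernel together with the number of rows it processes per call (see mul_mat_kernel and the MAKE_FUNCS macros below). The following is a minimal, self-contained sketch of that dispatch pattern, with hypothetical kernel and function names rather than the repository's actual code:

#include <cstdio>
#include <utility>

// simplified kernel signature; the real kernels also take data pointers, strides and a DataInfo
using mul_mat_t = void (*)(int D, int nq);

template <int k_step, int nrc_q>
void kernel(int D, int nq) {            // hypothetical kernel body
    std::printf("k_step=%d nrc_q=%d D=%d nq=%d\n", k_step, nrc_q, D, nq);
}

// map the runtime row count nq onto a compile-time nrc_q (capped at 8)
template <int k_step>
std::pair<mul_mat_t, int> pick_kernel(int nq) {
    if (nq >= 8) return {kernel<k_step, 8>, 8};
    switch (nq) {
        case 1: return {kernel<k_step, 1>, 1};
        case 2: return {kernel<k_step, 2>, 2};
        case 3: return {kernel<k_step, 3>, 3};
        case 4: return {kernel<k_step, 4>, 4};
        case 5: return {kernel<k_step, 5>, 5};
        case 6: return {kernel<k_step, 6>, 6};
        default: return {kernel<k_step, 7>, 7};
    }
}

// map the runtime k_step onto a compile-time template parameter
std::pair<mul_mat_t, int> pick_kernel(int nq, int k_step) {
    switch (k_step) {
        case 32: return pick_kernel<32>(nq);
        case 64: return pick_kernel<64>(nq);
        default: return pick_kernel<128>(nq);
    }
}

int main() {
    // full blocks of nrc_q rows run through the selected kernel; leftover rows
    // re-dispatch once to a narrower instantiation, as iqk_gemm_q8kv_fa and
    // iqk_gemm_legacy_fa do in the commit below
    auto [mm, nrc_q] = pick_kernel(5, 64);
    mm(128, 5);
    return 0;
}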
@@ -566,6 +566,67 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t,
}

void iqk_gemm_default_floats(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
    using q_float = float;
#ifdef HAVE_FANCY_SIMD
    constexpr int nrc_q = 8;
    constexpr int nrc_k = 8;
#else
    // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
    constexpr int nrc_q = 4;
    constexpr int nrc_k = 8;
#endif
    GGML_ASSERT(k_step%nrc_k == 0);
    int qrem = nq - nrc_q*(nq/nrc_q);
    for (int iq = 0; iq < nq/nrc_q; ++iq) {
        for (int ik = 0; ik < k_step/nrc_k; ++ik) {
            mul_mat_Qx_Qy_MxN_fa4<QFT<float, nrc_q>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
        }
        info.cur_y += nrc_q;
    }
    if (qrem > 0) {
        switch (qrem) {
            case 1: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
            case 2: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
            case 3: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
#ifdef HAVE_FANCY_SIMD
            case 4: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
            case 5: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
            case 6: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
            case 7: {
                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                    mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, cx, bx, ik*nrc_k, info);
                }
            } break;
#endif
        }
    }
}

#else
// ----------------------------------- __aarch64__ -----------------------------------------------
@@ -8,4 +8,6 @@
bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);

void iqk_gemm_default_floats(int D, int nq, const char * vx, size_t bx, DataInfo& info, int k_step);

#endif
@@ -2849,4 +2849,272 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
#endif

namespace {

#ifdef __AVX2__
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(n%32 == 0);
    if (nrc_y == 1 && nrc_x == 1) {
        auto dx = (const float *)vx;
        auto dy = (const float *)info.src1_row(0);
#ifdef HAVE_FANCY_SIMD
        auto sy = (const int32_t *)(dy + 1);
        auto x = (const int8_t *)(dx + 2);
        auto y = (const int8_t *)(dy + 2);
        auto isum = _mm512_setzero_si512();
        for (int i = 0; i < n/64; ++i) {
            auto qx = _mm512_loadu_si512((const __m512i *)x + i);
            auto qy = _mm512_loadu_si512((const __m512i *)y + i);
            isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
        }
        auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
        for (int i = 2*(n/64); i < n/32; ++i) {
            auto qx = _mm256_loadu_si256((const __m256i *)x + i);
            auto qy = _mm256_loadu_si256((const __m256i *)y + i);
            isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
        }
        info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
#else
        auto x = (const int8_t *)(dx + 2);
        auto y = (const int8_t *)(dy + 2);
        auto isum = _mm256_setzero_si256();
        for (int i = 0; i < n/32; ++i) {
            auto qx = _mm256_loadu_si256((const __m256i *)x + i);
            auto qy = _mm256_loadu_si256((const __m256i *)y + i);
            auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
            isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
        }
        info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
#endif
        return;
    }
    __m256i qx[2];
    __m256i acc[2*nrc_y] = {};
    float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
    int32_t sy[nrc_y];
#else
    __m256i sx[2];
    auto m1 = _mm256_set1_epi16(1);
#endif
    const int8_t * q8y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto dptr = (const float *)info.src1_row(iy);
        dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
        auto iptr = (const int32_t *)(dptr+1);
        sy[iy] = -127*iptr[0];
#endif
        q8y[iy] = (const int8_t *)(dptr + 2);
    }
    for (int ix = 0; ix < nrc_x; ++ix) {
        auto dx = (const float *)((const char *)vx + ix*bx);
        auto q8x = (const int8_t *)(dx + 2);
        for (int i = 0; i < n/64; ++i) {
            for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
                qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
#else
                qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
                sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
#endif
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
                    acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
#else
                    auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
                    acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
#endif
                }
            }
        }
        if (int i = 2*(n/64); i < n/32) {
#ifdef HAVE_FANCY_SIMD
            qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
#else
            qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
            sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
#endif
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
#else
                auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
                acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
#endif
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
#ifdef HAVE_FANCY_SIMD
            info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
#else
            info.store(ix, iy, dx[0]*dy[iy]*sumi);
#endif
            acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
        }
    }
}

#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    GGML_ASSERT(n%32 == 0);
    __m512i qx[4];
    __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
    float dy[nrc_y];
    int32_t sy[nrc_y];
    const int8_t * q8y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto dptr = (const float *)info.src1_row(iy);
        dy[iy] = dptr[0];
        auto iptr = (const int32_t *)(dptr + 1);
        sy[iy] = -64*iptr[0];
        q8y[iy] = (const int8_t *)(dptr + 2);
    }
    const int8_t * q8x[8];
    float dx[8];
    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int kx = 0; kx < 8; ++kx) {
            auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
            dx[kx] = dptr[0];
            q8x[kx] = (const int8_t *)(dptr + 2);
        }
        for (int i = 0; i < n/32; ++i) {
            for (int kx = 0; kx < 4; ++kx) {
                qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
                                            _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
            }
            auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
            auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
            auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
            auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
            qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
            qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
            qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
            qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
                auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
                if constexpr (nrc_y <= 4) {
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                } else {
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                }
            }
        }
        auto scales_x = _mm256_loadu_ps(dx);
        for (int iy = 0; iy < nrc_y; ++iy) {
            if constexpr (nrc_y <= 4) {
                auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
                auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
                auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
                auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
                info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
                info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
            } else {
                acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
                auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
                auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
                auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
                info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
                info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
                acc[iy] = _mm512_setzero_si512();
            }
        }
    }
}
#endif
#endif

template <int k_step>
inline std::pair<mul_mat_t, int> mul_mat_kernel(int D, int int_typeA, int nq) {
    auto typeA = ggml_type(int_typeA);
    constexpr int kMaxQ = 8;
#define MAKE_FUNCS(mul_mat, n) \
    if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
    else {\
        switch (n) {\
            case 1: return std::make_pair(mul_mat, 1>, 1);\
            case 2: return std::make_pair(mul_mat, 2>, 2);\
            case 3: return std::make_pair(mul_mat, 3>, 3);\
            case 4: return std::make_pair(mul_mat, 4>, 4);\
            case 5: return std::make_pair(mul_mat, 5>, 5);\
            case 6: return std::make_pair(mul_mat, 6>, 6);\
            case 7: return std::make_pair(mul_mat, 7>, 7);\
        }\
    }
#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
    if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
    else {\
        switch (n) {\
            case 1: return std::make_pair(mul_mat<1>, 1);\
            case 2: return std::make_pair(mul_mat<2>, 2);\
            case 3: return std::make_pair(mul_mat<3>, 3);\
            case 4: return std::make_pair(mul_mat<4>, 4);\
            case 5: return std::make_pair(mul_mat<5>, 5);\
            case 6: return std::make_pair(mul_mat<6>, 6);\
            case 7: return std::make_pair(mul_mat<7>, 7);\
        }\
    }
    if (typeA == GGML_TYPE_Q8_KV) {
#ifdef __aarch64__
        if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
        if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
        MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#else
        if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
#ifdef HAVE_FANCY_SIMD
        if (D%32 == 0 && k_step%8 == 0) {
            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
        } else {
            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
        }
#endif
        MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#endif
    }
    else if (typeA == GGML_TYPE_Q8_KV_R8) {
        MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
    }
    GGML_ABORT("Fatal error");
}

inline std::pair<mul_mat_t, int> mul_mat_kernel(int D, int int_typeA, int nq, int k_step) {
    switch (k_step) {
        case 32: return mul_mat_kernel< 32>(D, int_typeA, nq);
        case 64: return mul_mat_kernel< 64>(D, int_typeA, nq);
        case 128: return mul_mat_kernel<128>(D, int_typeA, nq);
        default: GGML_ABORT("Fatal error");
    }
}

}

void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) {
    auto [mul_mat, nrc_q] = mul_mat_kernel(D, type_k, nq, k_step);
    for (int iq = 0; iq < nq/nrc_q; ++iq) {
        mul_mat(D, k, stride_k, info, k_step);
        info.cur_y += nrc_q;
    }
    int iq = nrc_q*(nq/nrc_q);
    if (iq < nq) {
        auto [mul_mat1, nrc_q1] = mul_mat_kernel(D, type_k, nq - iq, k_step);
        GGML_ASSERT(nrc_q1 == nq - iq);
        mul_mat1(D, k, stride_k, info, k_step);
    }
}

#endif
@@ -8,4 +8,6 @@
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);

void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);

#endif
@@ -2635,4 +2635,129 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
#endif

namespace {
template <int k_step>
inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
    auto typeA = ggml_type(int_typeA);
    constexpr int kMaxQ = 8;
#define MAKE_FUNCS(mul_mat, n) \
    if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
    else {\
        switch (n) {\
            case 1: return std::make_pair(mul_mat, 1>, 1);\
            case 2: return std::make_pair(mul_mat, 2>, 2);\
            case 3: return std::make_pair(mul_mat, 3>, 3);\
            case 4: return std::make_pair(mul_mat, 4>, 4);\
            case 5: return std::make_pair(mul_mat, 5>, 5);\
            case 6: return std::make_pair(mul_mat, 6>, 6);\
            case 7: return std::make_pair(mul_mat, 7>, 7);\
        }\
    }
#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
    if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
    else {\
        switch (n) {\
            case 1: return std::make_pair(mul_mat<1>, 1);\
            case 2: return std::make_pair(mul_mat<2>, 2);\
            case 3: return std::make_pair(mul_mat<3>, 3);\
            case 4: return std::make_pair(mul_mat<4>, 4);\
            case 5: return std::make_pair(mul_mat<5>, 5);\
            case 6: return std::make_pair(mul_mat<6>, 6);\
            case 7: return std::make_pair(mul_mat<7>, 7);\
        }\
    }
    if (typeA == GGML_TYPE_Q8_0) {
#ifdef __aarch64__
        MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
#ifdef HAVE_FANCY_SIMD
        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
        MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
#else
        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
        MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
#endif
    }
    else if (typeA == GGML_TYPE_Q8_0_R8) {
#ifdef __aarch64__
        MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
#else
        MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_2, nq);
#endif
    }
    else if (typeA == GGML_TYPE_Q6_0) {
#ifdef __aarch64__
        MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
        MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
#endif
    }
    else if (typeA == GGML_TYPE_Q4_0) {
#ifdef __aarch64__
        MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
        MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
#endif
    }
#if GGML_IQK_FA_ALL_QUANTS
    else if (typeA == GGML_TYPE_Q4_1) {
#ifdef __aarch64__
        MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
#else
        MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_1_Unpacker, nq);
#endif
    }
    else if (typeA == GGML_TYPE_IQ4_NL) {
#ifdef __aarch64__
        MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
#else
#ifdef HAVE_FANCY_SIMD
        MAKE_FUNCS(mul_mat_qX_1_q8_2_T<IQ4_NL_Unpacker, nq);
#else
        MAKE_FUNCS(mul_mat_qX_0_q8_0_T<IQ4_NL_Unpacker, nq);
#endif
#endif
    }
#endif
    else {
        GGML_ASSERT(false);
    }
    return std::make_pair<mul_mat_t, int>(nullptr, 0);
}

inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq, int k_step) {
    switch (k_step) {
        case 32: return mul_mat_kernel< 32>(int_typeA, nq);
        case 64: return mul_mat_kernel< 64>(int_typeA, nq);
        case 128: return mul_mat_kernel<128>(int_typeA, nq);
        default: GGML_ABORT("Fatal error");
    }
}
}

void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) {
    auto [mul_mat, nrc_q] = mul_mat_kernel(type_k, nq, k_step);
    for (int iq = 0; iq < nq/nrc_q; ++iq) {
        mul_mat(D, k, stride_k, info, k_step);
        info.cur_y += nrc_q;
    }
    int iq = nrc_q*(nq/nrc_q);
    if (iq < nq) {
        auto [mul_mat1, nrc_q1] = mul_mat_kernel(type_k, nq - iq, k_step);
        GGML_ASSERT(nrc_q1 == nq - iq);
        mul_mat1(D, k, stride_k, info, k_step);
    }
}

#endif
@@ -5,7 +5,10 @@
#ifdef IQK_IMPLEMENT

#include <array>
#include <utility>

bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);

void iqk_gemm_legacy_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);

#endif
@@ -487,191 +487,6 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n

namespace {
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {

    (void)Ny;
@@ -1281,6 +1096,7 @@ template <int D, int step>
struct HelperQ8KV final : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    using block_q8 = block_q8_KV<D>;
    constexpr static ggml_type type = GGML_TYPE_Q8_KV;
    constexpr static int block_size_q = D;
    HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
@@ -1308,6 +1124,7 @@ struct HelperQ8KV final : public BaseHelper<step> {
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    constexpr static ggml_type type = GGML_TYPE_Q8_0;
#ifdef HAVE_FANCY_SIMD
    using block_q8 = block_q8_2;
    constexpr static int block_size_q = QK8_2;
@@ -1491,6 +1308,7 @@ namespace {
template <int D, int step>
struct HelperQ80R8 : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    constexpr static ggml_type type = GGML_TYPE_Q8_0_R8;
#ifdef __AVX2__
    constexpr static int block_size_q = QK8_2;
    using block_q8 = block_q8_2;
@@ -1613,6 +1431,7 @@ struct HelperQ80R8 : public BaseHelper<step> {
template <int D, int step>
struct HelperQ8KVR8 : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    constexpr static ggml_type type = GGML_TYPE_Q8_KV_R8;
    constexpr static int block_size_q = D;
    using block_q8 = block_q8_KV<D>;
@@ -1708,6 +1527,7 @@ struct HelperQ8KVR8 : public BaseHelper<step> {
template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    constexpr static ggml_type type = GGML_TYPE_Q4_0;
#if defined __AVX2__
    using block_q8 = block_q8_2;
    constexpr static int block_size_q = QK8_2;
@@ -1758,6 +1578,7 @@ template <int D, int step>
struct HelperQ41 final : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    using block_q8 = block_q8_2;
    constexpr static ggml_type type = GGML_TYPE_Q4_1;
    constexpr static int block_size_q = QK8_2;
    HelperQ41(const char * data, int stride) : Base(data, stride) {}
@@ -1800,6 +1621,7 @@ struct HelperQ41 final : public BaseHelper<step> {
template <int D, int step>
struct HelperIQ4nl final : public BaseHelper<step> {
    using Base = BaseHelper<step>;
    constexpr static ggml_type type = GGML_TYPE_IQ4_NL;
#ifdef __aarch64__
    using block_q8 = block_q8_0;
    HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {}
@@ -1854,6 +1676,7 @@ struct HelperIQ4nl final : public BaseHelper<step> {

template <int D, int step>
struct HelperQ60 final : public BaseHelper<step> {
    constexpr static ggml_type type = GGML_TYPE_Q6_0;
#ifdef __aarch64__
    using block_q8 = block_q8_0;
    constexpr static int block_size_q = QK8_0;
@@ -2409,28 +2232,13 @@ struct FlashQKfp32 {
    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
            FlashMS<q_step, k_step>& fms) {
#ifdef HAVE_FANCY_SIMD
        constexpr int nrc_q = 8;
        constexpr int nrc_k = 8;
#else
        // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
        constexpr int nrc_q = 4;
        constexpr int nrc_k = 8;
#endif
        constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
        static_assert(krem == 0);
        static_assert(k_step%nrc_k == 0);
        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
            }
            info.cur_y += nrc_q;
        }
        if constexpr (qrem > 0) {
            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
            }
        }
        iqk_gemm_default_floats(D, q_step, kh.block, kh.stride, info, k_step);
        F16::Data vk[k_step/F16::block_size];
        for (int j = 0; j < q_step; ++j) {
            fms.update_M_S(j, vk, mask + stride_m*j);
@@ -2473,64 +2281,10 @@ struct FlashQKfp32 {
    template <typename KHelper, typename q_float>
    static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
            FlashMS<q_step, k_step>& fms) {
#ifdef HAVE_FANCY_SIMD
        constexpr int nrc_q = 8;
        constexpr int nrc_k = 8;
#else
        // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
        constexpr int nrc_q = 4;
        constexpr int nrc_k = 8;
#endif
        static_assert(k_step%nrc_k == 0);
        int qrem = nq - nrc_q*(nq/nrc_q);
        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
        for (int iq = 0; iq < nq/nrc_q; ++iq) {
            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
            }
            info.cur_y += nrc_q;
        }
        if (qrem > 0) {
            switch (qrem) {
                case 1: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
                case 2: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
                case 3: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
#ifdef HAVE_FANCY_SIMD
                case 4: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
                case 5: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
                case 6: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
                case 7: {
                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
                    }
                } break;
#endif
            }
        }
        iqk_gemm_default_floats(D, nq, kh.block, kh.stride, info, k_step);
        F16::Data vk[k_step/F16::block_size];
        for (int j = 0; j < nq; ++j) {
            fms.update_M_S(j, vk, mask + stride_m*j);
@@ -2603,136 +2357,17 @@ struct FlashQKfp32 {
    }
#endif

    template <typename KHelper>
    static inline std::pair<mul_mat_t, int> mul_mat_kernel(int nq) {
        constexpr int kMaxQ = 8;
#define MAKE_FUNCS(mul_mat, n) \
        if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
        else {\
            switch (n) {\
                case 1: return std::make_pair(mul_mat, 1>, 1);\
                case 2: return std::make_pair(mul_mat, 2>, 2);\
                case 3: return std::make_pair(mul_mat, 3>, 3);\
                case 4: return std::make_pair(mul_mat, 4>, 4);\
                case 5: return std::make_pair(mul_mat, 5>, 5);\
                case 6: return std::make_pair(mul_mat, 6>, 6);\
                case 7: return std::make_pair(mul_mat, 7>, 7);\
            }\
        }
#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
        if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
        else {\
            switch (n) {\
                case 1: return std::make_pair(mul_mat<1>, 1);\
                case 2: return std::make_pair(mul_mat<2>, 2);\
                case 3: return std::make_pair(mul_mat<3>, 3);\
                case 4: return std::make_pair(mul_mat<4>, 4);\
                case 5: return std::make_pair(mul_mat<5>, 5);\
                case 6: return std::make_pair(mul_mat<6>, 6);\
                case 7: return std::make_pair(mul_mat<7>, 7);\
            }\
        }
        if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
#ifdef __aarch64__
            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
#ifdef HAVE_FANCY_SIMD
            if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
            if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
            if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
#else
            if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
            if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
            if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
            MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
#endif
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
#ifdef __aarch64__
            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
            if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#else
            if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
#ifdef HAVE_FANCY_SIMD
            if constexpr (D%32 == 0 && k_step%8 == 0) {
                if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
                MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
            } else {
                if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
            }
#endif
            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#endif
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
#ifdef __aarch64__
            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
#else
            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_2, nq);
#endif
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__
            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
            if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
            if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
            if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
#endif
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
#ifdef __aarch64__
            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
            if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
            if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
            if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
#endif
        }
#if GGML_IQK_FA_ALL_QUANTS
        else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
#ifdef __aarch64__
            MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
#else
            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_1_Unpacker, nq);
#endif
        }
        else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
#ifdef __aarch64__
            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
#else
#ifdef HAVE_FANCY_SIMD
            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<IQ4_NL_Unpacker, nq);
#else
            MAKE_FUNCS(mul_mat_qX_0_q8_0_T<IQ4_NL_Unpacker, nq);
#endif
#endif
        }
#endif
        else {
            GGML_ASSERT(false);
        }
        return std::make_pair<mul_mat_t, int>(nullptr, 0);
    }

    template <typename KHelper, typename block_q8>
    static inline void mul_mask_kq(const KHelper& kh, int stride_m,
            const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
        constexpr int kMaxQ = 8;
        static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
        auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
        DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
            mul_mat(D, kh.block, kh.stride, info, k_step);
            info.cur_y += nrc_q;
        if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>> ||
                      std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
            iqk_gemm_q8kv_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step);
        } else {
            iqk_gemm_legacy_fa(D, q_step, kh.type, kh.block, kh.stride, info, k_step);
        }
#ifdef __aarch64__
        float32x4_t vk[k_step/4];
@@ -2750,17 +2385,13 @@ struct FlashQKfp32 {
    template <typename KHelper, typename block_q8>
    static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
            const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
        auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
        GGML_ASSERT(nq < q_step);
        DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
        for (int iq = 0; iq < nq/nrc_q; ++iq) {
            mul_mat(D, kh.block, kh.stride, info, k_step);
            info.cur_y += nrc_q;
        }
        int iq = nrc_q*(nq/nrc_q);
        if (iq < nq) {
            auto [mul_mat1, nrc_q1] = mul_mat_kernel<KHelper>(nq - iq);
            GGML_ASSERT(nrc_q1 == nq - iq);
            mul_mat1(D, kh.block, kh.stride, info, k_step);
        if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>> ||
                      std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
            iqk_gemm_q8kv_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step);
        } else {
            iqk_gemm_legacy_fa(D, nq, kh.type, kh.block, kh.stride, info, k_step);
        }
#ifdef __aarch64__
        float32x4_t vk[k_step/4];