Better strategy for attention matrix multiplications when generating tokens (#218)

* This seems to be a better way to do the attention matrix multiplications in the TG (token generation) case.

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-02-22 09:38:51 +02:00
Committed by: GitHub
Parent: 17d43879c6
Commit: dcff697474

@@ -352,6 +352,26 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
return true;
}
if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
//printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;
}
int n_per_thread = (Nx + nth - 1)/nth;
int first = ith*n_per_thread;
if (first >= Nx) return true;
int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx;
for (int ix = first; ix < last; ++ix) {
for (int i02 = 0; i02 < ne02; ++i02) {
DataInfo info{C + ix + i02*nb2, (const char *)B + i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
mm.funcs[0](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), nb02, info, 1);
}
}
return true;
}
int gcd = simple_gcd(ne12*ne13, nth);
int counter = 0;
for (int64_t i13 = 0; i13 < ne13; i13++) {
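The branch added in the hunk above is the token-generation fast path: a single sequence (ne13 == 1), one new token (Ny == 1), and ne12 == ne02 attention heads, with nb02 < strideA indicating that the per-head blocks of A are packed more tightly than its rows. Instead of parallelising over heads, the Nx rows are split across the nth threads, and each thread runs one tiny GEMV per (row, head) pair. Below is a minimal standalone sketch of that dispatch, assuming plain float data and element (not byte) strides; the name tg_attention_gemv and the scalar inner loop are illustrative, not code from the repository.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical illustration of the per-row / per-head TG dispatch: for token
// generation B holds a single activation row per head, so each (row, head)
// pair reduces to one dot product of length ne00.
static void tg_attention_gemv(int Nx, int ne00, int ne02,
                              const float* A, int64_t strideA, int64_t nb02,
                              const float* B, int64_t nb12,
                              float* C, int64_t nb2,
                              int ith, int nth) {
    // Same work split as in the diff: thread ith handles rows [first, last).
    int n_per_thread = (Nx + nth - 1) / nth;
    int first = ith * n_per_thread;
    if (first >= Nx) return;
    int last = std::min(first + n_per_thread, Nx);

    for (int ix = first; ix < last; ++ix) {
        for (int i02 = 0; i02 < ne02; ++i02) {                  // loop over heads
            const float* a = A + ix * strideA + i02 * nb02;     // row ix of head i02
            const float* b = B + i02 * nb12;                    // single activation row of head i02
            float sum = 0.0f;
            for (int k = 0; k < ne00; ++k) sum += a[k] * b[k];  // the per-(row, head) dot product
            C[ix + i02 * nb2] = sum;                            // matches the DataInfo store layout
        }
    }
}
```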
@@ -6229,8 +6249,8 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
GGML_ASSERT(nrc_x%8 == 0);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
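The code comment retained in the hunk above notes that HAVE_FANCY_SIMD should only be defined when both AVX512-VNNI and AVX512-VL are available, but its #if expression is written with unbalanced parentheses. A hedged sketch of the guard it presumably has in mind, using the standard compiler feature macros; the actual definition site of HAVE_FANCY_SIMD is not part of this diff:

```cpp
// Sketch only: where HAVE_FANCY_SIMD is really defined lies outside this diff.
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
#define HAVE_FANCY_SIMD
#endif
```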
@@ -6298,8 +6318,42 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
if (nrc_y == 1 && nrc_x == 1) {
auto dx = (const float *)vx;
auto dy = (const float *)info.src1_row(0);
#ifdef HAVE_FANCY_SIMD
auto sy = (const int32_t *)(dy + 1);
auto x = (const int8_t *)(dx + 2);
auto y = (const int8_t *)(dy + 2);
auto isum = _mm512_setzero_si512();
for (int i = 0; i < n/64; ++i) {
auto qx = _mm512_loadu_si512((const __m512i *)x + i);
auto qy = _mm512_loadu_si512((const __m512i *)y + i);
isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
}
auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
for (int i = 2*(n/64); i < n/32; ++i) {
auto qx = _mm256_loadu_si256((const __m256i *)x + i);
auto qy = _mm256_loadu_si256((const __m256i *)y + i);
isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
}
info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
#else
auto x = (const int8_t *)(dx + 2);
auto y = (const int8_t *)(dy + 2);
auto isum = _mm256_setzero_si256();
for (int i = 0; i < n/32; ++i) {
auto qx = _mm256_loadu_si256((const __m256i *)x + i);
auto qy = _mm256_loadu_si256((const __m256i *)y + i);
auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
}
info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
#endif
return;
}
GGML_ASSERT(nrc_x%8 == 0);
__m256i qx[2];
__m256i acc[2*nrc_y] = {};
float dy[nrc_y];
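The single-row fast path added in this last hunk leans on two byte-dot-product tricks. On the VNNI path, _mm512_dpbusd_epi32 multiplies unsigned bytes by signed bytes, so the signed qx values are biased by +127 before the multiply and the surplus 127*sum(y) is subtracted at the end using sy[0], which appears to be the precomputed sum of y's quantized values stored right after the scale in the q8_KV row. On the AVX2 fallback, _mm256_maddubs_epi16 likewise wants an unsigned first operand, so |qx| is fed in and the sign of qx is transferred onto qy via _mm256_sign_epi8, which leaves every product x*y unchanged. A scalar sketch of both identities, with illustrative names (dot_biased, dot_signed) that do not exist in the repository:

```cpp
#include <cstdint>
#include <cstdlib>

// VNNI-style trick: use non-negative (x + 127) and correct afterwards.
// sum((x+127)*y) = sum(x*y) + 127*sum(y)  =>  sum(x*y) = sum((x+127)*y) - 127*sum(y).
// (The SIMD version additionally relies on x + 127 fitting in an unsigned byte,
//  i.e. on the quantizer keeping x >= -127.)
int32_t dot_biased(const int8_t* x, const int8_t* y, int n, int32_t sum_y) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i) acc += (int32_t(x[i]) + 127) * int32_t(y[i]);
    return acc - 127 * sum_y;
}

// maddubs-style trick: feed |x| (treated as unsigned) and move x's sign onto y.
// |x| * sign(x)*y == x*y for every pair, with x == 0 giving 0 either way.
int32_t dot_signed(const int8_t* x, const int8_t* y, int n) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i) {
        int32_t ax = std::abs(int32_t(x[i]));
        int32_t ys = x[i] < 0 ? -int32_t(y[i]) : (x[i] > 0 ? int32_t(y[i]) : 0);
        acc += ax * ys;
    }
    return acc;
}
```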