mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-02 20:48:03 +00:00)
Better strategy for attention matrix multiplications when generating tokens (#218)
* This seems to be a better way to do the attention matrix multiplications in the TG case.

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
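In rough terms, the first hunk below adds a special case for the attention matrix multiplications during token generation (TG): there is a single activation row (Ny == 1), a single batch (ne13 == 1), and one slice of the quantized matrix per attention head (ne12 == ne02). Rather than distributing the ne12*ne13 head slices over threads, the Nx rows are split across threads and each row is multiplied against every head. A minimal sketch of that scheduling idea, with hypothetical names (gemv_one_row stands in for the real MulMat kernels and is not a function from the repository):

#include <algorithm>
#include <cstddef>

// Hypothetical stand-in for one call into the kernel table (mm.funcs[0] in the diff).
static void gemv_one_row(const char* /*a_row*/, const char* /*b_row*/, float* /*c_out*/) {
    // ... a quantized dot product over ne00 elements would happen here ...
}

// Sketch of the TG scheduling: rows split across threads, heads looped inside.
static void tg_attention_sketch(const char* A, const char* B, float* C,
                                int Nx, int n_heads, std::size_t strideA, std::size_t nb02,
                                std::size_t nb12, std::size_t nb2, int nth, int ith) {
    int n_per_thread = (Nx + nth - 1)/nth;            // ceil(Nx/nth) rows per thread
    int first = ith*n_per_thread;
    if (first >= Nx) return;                          // surplus threads have nothing to do
    int last  = std::min(first + n_per_thread, Nx);   // clamp the final chunk
    for (int ix = first; ix < last; ++ix) {           // rows owned by this thread
        for (int i02 = 0; i02 < n_heads; ++i02) {     // all heads for each row
            gemv_one_row(A + ix*strideA + i02*nb02,   // row ix of head i02 in A
                         B + i02*nb12,                // the single activation row of head i02
                         C + ix + i02*nb2);           // destination for (row, head)
        }
    }
}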
@@ -352,6 +352,26 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
        }
        return true;
    }

    if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
        //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
        MulMat mm;
        if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
            return false;
        }
        int n_per_thread = (Nx + nth - 1)/nth;
        int first = ith*n_per_thread;
        if (first >= Nx) return true;
        int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx;
        for (int ix = first; ix < last; ++ix) {
            for (int i02 = 0; i02 < ne02; ++i02) {
                DataInfo info{C + ix + i02*nb2, (const char *)B + i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
                mm.funcs[0](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), nb02, info, 1);
            }
        }
        return true;
    }

    int gcd = simple_gcd(ne12*ne13, nth);
    int counter = 0;
    for (int64_t i13 = 0; i13 < ne13; i13++) {
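The per-thread bounds in the new block are the usual ceil-divide split, with the last chunk clamped to Nx and surplus threads returning early. A small standalone illustration with made-up sizes (Nx = 5 rows, nth = 4 threads), just to show that every row is covered exactly once:

#include <cstdio>

int main() {
    int Nx = 5, nth = 4;                        // hypothetical sizes for the example
    int n_per_thread = (Nx + nth - 1)/nth;      // = 2
    for (int ith = 0; ith < nth; ++ith) {
        int first = ith*n_per_thread;
        if (first >= Nx) { std::printf("thread %d: idle\n", ith); continue; }
        int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx;
        std::printf("thread %d: rows [%d, %d)\n", ith, first, last);
    }
    // prints rows [0,2), [2,4), [4,5), and "thread 3: idle"
    return 0;
}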
@@ -6229,8 +6249,8 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    GGML_ASSERT(n%32 == 0);
    GGML_ASSERT(nrc_x%8 == 0);
#ifndef HAVE_FANCY_SIMD
    auto m1 = _mm256_set1_epi16(1);
#endif
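The comment kept in this hunk says HAVE_FANCY_SIMD should be keyed on AVX512-VNNI together with AVX512-VL. Presumably the intended guard looks roughly like the following, using the predefined GCC/Clang macros __AVX512VNNI__ and __AVX512VL__; this is a sketch of what the comment suggests, not the repository's actual definition of HAVE_FANCY_SIMD:

// Sketch of the guard the in-code comment appears to ask for.
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
#  define HAVE_FANCY_SIMD
#endif

#ifdef HAVE_FANCY_SIMD
// dpbusd-based paths (_mm256_dpbusd_epi32 / _mm512_dpbusd_epi32)
#else
// fallback via _mm256_maddubs_epi16 + _mm256_madd_epi16
#endif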
@@ -6298,8 +6318,42 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data

template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    GGML_ASSERT(n%32 == 0);
    if (nrc_y == 1 && nrc_x == 1) {
        auto dx = (const float *)vx;
        auto dy = (const float *)info.src1_row(0);
#ifdef HAVE_FANCY_SIMD
        auto sy = (const int32_t *)(dy + 1);
        auto x = (const int8_t *)(dx + 2);
        auto y = (const int8_t *)(dy + 2);
        auto isum = _mm512_setzero_si512();
        for (int i = 0; i < n/64; ++i) {
            auto qx = _mm512_loadu_si512((const __m512i *)x + i);
            auto qy = _mm512_loadu_si512((const __m512i *)y + i);
            isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
        }
        auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
        for (int i = 2*(n/64); i < n/32; ++i) {
            auto qx = _mm256_loadu_si256((const __m256i *)x + i);
            auto qy = _mm256_loadu_si256((const __m256i *)y + i);
            isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
        }
        info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
#else
        auto x = (const int8_t *)(dx + 2);
        auto y = (const int8_t *)(dy + 2);
        auto isum = _mm256_setzero_si256();
        for (int i = 0; i < n/32; ++i) {
            auto qx = _mm256_loadu_si256((const __m256i *)x + i);
            auto qy = _mm256_loadu_si256((const __m256i *)y + i);
            auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
            isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
        }
        info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
#endif
        return;
    }
    GGML_ASSERT(nrc_x%8 == 0);
    __m256i qx[2];
    __m256i acc[2*nrc_y] = {};
    float dy[nrc_y];
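Two standard tricks carry the new single-row (nrc_y == 1 && nrc_x == 1) path. The VNNI instruction _mm512_dpbusd_epi32 multiplies unsigned bytes by signed bytes, so the signed x quants are biased by +127 and the bias is removed at the end using the row sum that the q8_KV layout appears to store right after the scale (sy[0]). The AVX2 fallback uses the sign trick instead: _mm256_sign_epi8(qx, qx) makes the first operand non-negative and _mm256_sign_epi8(qy, qx) moves the signs of x onto y, so _mm256_maddubs_epi16 still accumulates x*y. A scalar sketch of the same arithmetic, assuming the q8_KV row layout implied by the pointer casts in the diff (one float scale, one int32 sum of quants, then n int8 quants):

#include <cstdint>

// Scalar reference for the dot product the SIMD code above computes.
// Layout assumption (inferred from the casts, not verified against the format spec):
//   row = [ float d ][ int32_t sum of quants ][ int8_t q[n] ]
static float q8_KV_dot_sketch(const void* vx, const void* vy, int n) {
    auto dx = (const float   *)vx;
    auto dy = (const float   *)vy;
    auto x  = (const int8_t  *)(dx + 2);     // quants start after the 8-byte header
    auto y  = (const int8_t  *)(dy + 2);
    auto sy = (const int32_t *)(dy + 1);     // precomputed sum of the y quants

    // What dpbusd effectively accumulates: (x + 127)*y, with x shifted into unsigned range.
    int32_t biased = 0;
    for (int i = 0; i < n; ++i) biased += (int32_t(x[i]) + 127)*int32_t(y[i]);

    // Removing the bias with the stored row sum recovers the plain dot product:
    //   sum (x_i + 127)*y_i - 127*sum y_i == sum x_i*y_i
    int32_t isum = biased - 127*sy[0];

    return dx[0]*dy[0]*float(isum);          // apply the two row scales
}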