Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-04-25 17:09:22 +00:00)
Fix standard attention on the CPU
This commit is contained in:
@@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
|
||||
int ny = mm.funcs.size();
|
||||
while (ny > 0 && !mm.funcs[ny-1]) --ny;
|
||||
if (ny >= r2) {
|
||||
int nx64 = Nx/64;
|
||||
int nchunk64 = nx64*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
|
||||
int i02 = ichunk/nx64;
|
||||
int ix = 64*(ichunk - i02*nx64);
|
||||
nchunk = nx32*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
int i02 = ichunk/nx32;
|
||||
int ix = 32*(ichunk - i02*nx32);
|
||||
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
|
||||
}
|
||||
int ix0 = 64*nx64;
|
||||
if (ix0 < Nx) {
|
||||
nx32 -= 2*nx64;
|
||||
nchunk = nx32*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
int i02 = ichunk/nx32;
|
||||
int ix = ix0 + 32*(ichunk - i02*nx32);
|
||||
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
|
||||
}
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
int i02 = ichunk/nx32;
|
||||
@@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
|
||||
}
|
||||
return true;
|
||||
}
|
||||
//if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
|
||||
int gcd = simple_gcd(ne02, nth);
|
||||
int counter = 0;
|
||||
for (int64_t i12 = 0; i12 < ne02; i12++) {
|
||||
@@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
|
||||
}
|
||||
|
||||
if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
|
||||
//printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
|
||||
MulMat mm;
|
||||
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user