mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
Fix SER (CPU) (#415)
* Fixing SER bugs * Cleanup --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -12472,6 +12472,11 @@ static void ggml_compute_forward_sum_rows_f32(
|
||||
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
float row_sum = 0;
|
||||
ggml_vec_sum_f32(ne00, &row_sum, src_row);
|
||||
if (!isfinite(row_sum)) {
|
||||
fprintf(stderr, "Oops(%s, %s): found %g for i1 = %d, i2 = %d, i3 = %d. ne00 = %d\n", __func__, dst->name,
|
||||
(double)row_sum, (int)i1, (int)i2, (int)i3, (int)ne00);
|
||||
exit(1);
|
||||
}
|
||||
dst_row[0] = row_sum;
|
||||
}
|
||||
}
|
||||
@@ -14759,6 +14764,18 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
||||
|
||||
GGML_ASSERT(ids->ne[1] == dst->ne[2]);
|
||||
for (int64_t iid1 = ith; iid1 < ids->ne[1]; iid1 += nth) {
|
||||
for (int id = 0; id < n_ids; ++id) {
|
||||
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
if (i02 < 0 || i02 >= n_as) {
|
||||
// This is needed for SER. If fewer experts have been activated for this row, we need to
|
||||
// clear it, else there could be garbage that leads to NaNs later on.
|
||||
memset((char *)dst->data + id*dst->nb[1] + iid1*dst->nb[2], 0, dst->ne[0]*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
// initialize matrix_row_counts
|
||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
@@ -15012,6 +15029,18 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
||||
|
||||
GGML_ASSERT(ids->ne[1] == dst->ne[2]);
|
||||
for (int64_t iid1 = ith; iid1 < ids->ne[1]; iid1 += nth) {
|
||||
for (int id = 0; id < n_ids; ++id) {
|
||||
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
if (i02 < 0 || i02 >= n_as) {
|
||||
// This is needed for SER. If fewer experts have been activated for this row, we need to
|
||||
// clear it, else there could be garbage that leads to NaNs later on.
|
||||
memset((char *)dst->data + id*dst->nb[1] + iid1*dst->nb[2], 0, dst->ne[0]*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
// initialize matrix_row_counts
|
||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
@@ -15916,7 +15945,7 @@ static void ggml_compute_forward_get_rows_f16(
|
||||
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
|
||||
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
|
||||
} else {
|
||||
memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
|
||||
memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -15960,7 +15989,7 @@ static void ggml_compute_forward_get_rows_bf16(
|
||||
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
|
||||
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
|
||||
} else {
|
||||
memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
|
||||
memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,31 +458,29 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
|
||||
if (r2 <= 8) {
|
||||
MulMat mm;
|
||||
if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
|
||||
int nx64 = Nx/64;
|
||||
int nchunk64 = nx64*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
|
||||
int i02 = ichunk/nx64;
|
||||
int ix = 64*(ichunk - i02*nx64);
|
||||
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
|
||||
}
|
||||
int ix0 = 64*nx64;
|
||||
if (ix0 < Nx) {
|
||||
nx32 -= 2*nx64;
|
||||
nchunk = nx32*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
int i02 = ichunk/nx32;
|
||||
int ix = ix0 + 32*(ichunk - i02*nx32);
|
||||
int ny = mm.funcs.size();
|
||||
while (ny > 0 && !mm.funcs[ny-1]) --ny;
|
||||
if (ny >= r2) {
|
||||
int nx64 = Nx/64;
|
||||
int nchunk64 = nx64*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
|
||||
int i02 = ichunk/nx64;
|
||||
int ix = 64*(ichunk - i02*nx64);
|
||||
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
|
||||
}
|
||||
int ix0 = 64*nx64;
|
||||
if (ix0 < Nx) {
|
||||
nx32 -= 2*nx64;
|
||||
nchunk = nx32*ne02;
|
||||
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
int i02 = ichunk/nx32;
|
||||
int ix = ix0 + 32*(ichunk - i02*nx32);
|
||||
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
|
||||
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
|
||||
}
|
||||
}
|
||||
}
|
||||
//for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
// int i02 = ichunk/nx32;
|
||||
// int ix = 32*(ichunk - i02*nx32);
|
||||
// DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
|
||||
// mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
|
||||
//}
|
||||
return true;
|
||||
}
|
||||
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
|
||||
|
||||
Reference in New Issue
Block a user