iqk_mul_mat for llama.cpp

This commit is contained in:
Kawrakow
2024-05-27 09:51:08 +02:00
parent 9fa7946997
commit d434b4751a
7 changed files with 2586 additions and 31 deletions

49
ggml.c
View File

@@ -12334,11 +12334,7 @@ UseGgmlGemm1:;
#endif
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
return;
}
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
atomic_store(&state->shared->current_chunk, nth);
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -12346,16 +12342,45 @@ UseGgmlGemm1:;
assert(params->wsize >= ne11*ne12*ne13*row_size);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
int64_t work_size = ne13*ne12*ne11;
int64_t work_per_thread = (work_size + nth - 1)/nth;
int64_t work_start = work_per_thread * ith;
if (work_start >= work_size) {
return;
}
int64_t work_end = MIN(work_size, work_start + work_per_thread);
for (int64_t i_work = work_start; i_work < work_end; ++i_work) {
int64_t i13 = i_work / (ne11*ne12);
int64_t i12 = (i_work - i13*ne11*ne12)/ne11;
int64_t i11 = i_work - i13*ne11*ne12 - i12*ne11;
from_float_to_vec_dot((const float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *)(wdata + i_work*row_size), ne10);
}
}
if (ith == 0) {
atomic_store(&state->shared->current_chunk, nth);
}
//// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
//atomic_store(&state->shared->current_chunk, nth);
//if (src1->type != vec_dot_type) {
// char * wdata = params->wdata;
// const size_t row_size = ggml_row_size(vec_dot_type, ne10);
// assert(params->wsize >= ne11*ne12*ne13*row_size);
// GGML_ASSERT(src1->type == GGML_TYPE_F32);
// for (int64_t i13 = 0; i13 < ne13; ++i13) {
// for (int64_t i12 = 0; i12 < ne12; ++i12) {
// for (int64_t i11 = 0; i11 < ne11; ++i11) {
// from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
// wdata += row_size;
// }
// }
// }
//}
return;
}