diff --git a/src/llama-quantize.cpp b/src/llama-quantize.cpp
index 68abb7c7..bdd5321e 100644
--- a/src/llama-quantize.cpp
+++ b/src/llama-quantize.cpp
@@ -1531,6 +1531,48 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             const int64_t n_per_row = tensor->ne[0];
             const int64_t nrows = tensor->ne[1];
 
+            // for 3D tensors (e.g. MoE experts) parallelize across the ne[2] slices, one slice per work item
+            if (nthread > 1 && (tensor->ne[2] % nthread == 0 || tensor->ne[2] >= 2*nthread)) {
+                std::mutex mutex;
+                int counter = 0;
+                bool valid = true;
+                new_size = 0; // reset the per-tensor total before the workers accumulate into it
+                auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, tensor, imatrix] () {
+                    int ne2 = tensor->ne[2];
+                    auto row_size = ggml_row_size(new_type, tensor->ne[0]);
+                    auto matrix_size = row_size * tensor->ne[1];
+                    size_t local_size = 0;
+                    while (true) {
+                        std::unique_lock<std::mutex> lock(mutex);
+                        int i02 = counter++;
+                        if (i02 >= ne2) {
+                            if (local_size > 0) {
+                                new_size += local_size;
+                            }
+                            break;
+                        }
+                        lock.unlock();
+                        // each slice has its own importance matrix of ne[0] entries
+                        auto this_imatrix = imatrix ? imatrix + i02 * tensor->ne[0] : nullptr;
+                        auto this_data = (char *)new_data + i02*matrix_size;
+                        auto this_size = ggml_quantize_chunk(new_type, f32_data + i02*tensor->ne[0]*tensor->ne[1], this_data, 0, tensor->ne[1], tensor->ne[0], this_imatrix);
+                        local_size += this_size;
+
+                        // validate the quantized data
+                        if (!ggml_validate_row_data(new_type, this_data, matrix_size)) {
+                            lock.lock();
+                            valid = false;
+                            break;
+                        }
+                    }
+                };
+                for (int it = 0; it < nthread; ++it) workers.emplace_back(std::thread(compute));
+                for (auto & w : workers) w.join();
+                workers.clear();
+                if (!valid) {
+                    throw std::runtime_error("quantized data validation failed");
+                }
+            } else {
 
             static const int64_t min_chunk_size = 32 * 512;
             const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) * chunk_size_multiplier;
@@ -1548,6 +1590,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
                 new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
             }
+            }
         }
         LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
     }
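
For reference, the added branch is a shared-counter work queue: a mutex guards an integer counter, each worker thread claims the next ne[2] slice index under the lock, quantizes that slice with the lock released, and folds its private byte total into the shared total once the counter is exhausted. Below is a minimal standalone sketch of that pattern, independent of the patch; process_slice, kSlices, and kThreads are hypothetical stand-ins for ggml_quantize_chunk, tensor->ne[2], and nthread.

    // Minimal sketch of the shared-counter worker pattern used in the patch.
    // process_slice(), kSlices and kThreads are hypothetical stand-ins.
    #include <cstdio>
    #include <mutex>
    #include <thread>
    #include <vector>

    // stand-in for quantizing one ne[2] slice; returns the bytes produced
    static size_t process_slice(int i02) {
        return 1000 + (size_t) i02;
    }

    int main() {
        const int kSlices  = 8; // stand-in for tensor->ne[2]
        const int kThreads = 4; // stand-in for nthread

        std::mutex mutex;
        int        counter    = 0; // next slice index to hand out
        size_t     total_size = 0; // shared total, only touched under the lock

        auto compute = [&]() {
            size_t local_size = 0;
            while (true) {
                std::unique_lock<std::mutex> lock(mutex);
                const int i02 = counter++;      // claim the next slice
                if (i02 >= kSlices) {
                    total_size += local_size;   // fold in this thread's total
                    break;                      // lock released by the destructor
                }
                lock.unlock();                  // heavy work happens unlocked
                local_size += process_slice(i02);
            }
        };

        std::vector<std::thread> workers;
        for (int it = 0; it < kThreads; ++it) workers.emplace_back(compute);
        for (auto & w : workers) w.join();

        std::printf("total = %zu bytes\n", total_size);
        return 0;
    }

Using a mutex rather than a bare std::atomic counter is a reasonable choice here: in the patch the same lock also serializes the accumulation into new_size and the write to the shared valid flag, and it is only taken once per slice, so contention is negligible next to the cost of quantizing a slice.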