diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index 39fb5f83..6da4ab7a 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -2958,6 +2958,23 @@ bool create_tensors_helper::create_tensors() {
     int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
     auto cur_splits = model.splits;
     int adjust_step = std::max(1, int(model.layers.size() / (2*model.splits.size())));
+    if (model.max_gpu > 1 && model.max_gpu < int(cur_splits.size())) {
+        bool equal_split = true;
+        for (int i = 0; i < int(cur_splits.size()); ++i) {
+            float p = i > 0 ? cur_splits[i] - cur_splits[i-1] : cur_splits[i];
+            if (std::abs(p*cur_splits.size() - 1.f) > 0.02f) {
+                equal_split = false; break;
+            }
+        }
+        if (equal_split) {
+            if (cur_splits.size() % model.max_gpu == 0) {
+                int nadj = cur_splits.size()/model.max_gpu;
+                adjust_step = (model.layers.size() + nadj - 1) / nadj;
+            } else {
+                adjust_step = (model.layers.size() + cur_splits.size() - 1)/cur_splits.size();
+            }
+        }
+    }
     for (int il = 0; il < int(model.layers.size()); ++il) {
         if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) {
             LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il);
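
For reference, below is a minimal standalone sketch (not part of the patch) of the `adjust_step` selection the hunk adds. The helper name `pick_adjust_step`, the 61-layer count, and the 8-way split are made up for illustration; `model.max_gpu` is assumed here to cap how many GPUs a single layer's tensors may be split across, and `splits` is assumed to hold the cumulative per-device split fractions, as the surrounding code suggests.

```cpp
// Standalone sketch (hypothetical harness) mirroring the adjust_step logic in
// the diff above; the model/hparams plumbing is omitted.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// splits: cumulative per-device split fractions (last entry ~1.0)
// n_layers: model.layers.size(); max_gpu: assumed per-layer GPU cap
static int pick_adjust_step(const std::vector<float> & splits, int n_layers, int max_gpu) {
    int n_split = int(splits.size());
    // Default from the existing code: roughly two adjustment points per device.
    int adjust_step = std::max(1, n_layers / (2*n_split));
    if (max_gpu > 1 && max_gpu < n_split) {
        // The split is treated as "equal" if every per-device share is within
        // 2% of 1/n_split.
        bool equal_split = true;
        for (int i = 0; i < n_split; ++i) {
            float p = i > 0 ? splits[i] - splits[i-1] : splits[i];
            if (std::abs(p*n_split - 1.f) > 0.02f) { equal_split = false; break; }
        }
        if (equal_split) {
            if (n_split % max_gpu == 0) {
                // Group the devices into n_split/max_gpu chunks and give each
                // chunk a ceil-divided slice of the layers.
                int nadj = n_split/max_gpu;
                adjust_step = (n_layers + nadj - 1) / nadj;
            } else {
                adjust_step = (n_layers + n_split - 1)/n_split;
            }
        }
    }
    return adjust_step;
}

int main() {
    // Hypothetical example: 8 equal splits, 61 layers, at most 2 GPUs per layer.
    std::vector<float> splits = {0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f, 1.0f};
    // 8 % 2 == 0, so nadj = 4 and adjust_step = ceil(61/4) = 16,
    // instead of the default max(1, 61/16) = 3.
    std::printf("adjust_step = %d\n", pick_adjust_step(splits, 61, 2));
}
```

In words: when the user asked for at most `max_gpu` devices per layer and the configured splits are (near-)uniform, the step between split adjustments is enlarged so that layer ownership changes only `n_split/max_gpu` times (or `n_split` times when the counts do not divide evenly), rather than at the finer default granularity.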