Better VRAM utilization strategy for split mode graph (#1126)

* Better VRAM utilization strategy for split mode graph

* Fix assert when --max-gpu is less than available GPUs

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2026-01-09 13:36:02 +02:00
committed by GitHub
parent 8725d110d2
commit 08a0da389c

View File

@@ -249,7 +249,8 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
#endif
}
-static std::vector<int> create_split(int nr, int granularity, const std::vector<float> & splits, const std::vector<size_t> & mem_used) {
+static std::vector<int> create_split(int nr, int granularity, const std::vector<float> & splits, const std::vector<size_t> & mem_used,
+        bool verbose = false) {
GGML_ASSERT(nr % granularity == 0);
GGML_ASSERT(!splits.empty());
if (granularity < 0) return std::vector<int>(splits.size(), nr);
@@ -260,21 +261,25 @@ static std::vector<int> create_split(int nr, int granularity, const std::vector<
std::vector<int> result(splits.size());
float last_split = 0;
int sum = 0;
if (verbose) printf("--- %s: %d chunks\n", __func__, nchunk);
for (int i = 0; i < (int)splits.size(); ++i) {
float p = splits[i] - last_split;
float p0 = p;
p += (p - 1.f*mem_used[i]/tot_memory_used);
result[i] = roundf(p*nchunk);
if (result[i] < 0) result[i] = 0;
if (verbose) printf("i = %d, p0 = %g, p = %g, result = %d\n", i, p0, p, result[i]);
sum += result[i];
last_split = splits[i];
}
while (sum > nchunk) {
last_split = 0;
-        float best_err = 0;
+        float best_err = -INFINITY;
int ibest = -1;
for (int i = 0; i < (int)splits.size(); ++i) {
if (result[i] > 0) {
float p = splits[i] - last_split;
p += (p - 1.f*mem_used[i]/tot_memory_used);
float n_want = p*nchunk;
float err = result[i] - n_want;
if (err > best_err) {
@@ -289,10 +294,11 @@ static std::vector<int> create_split(int nr, int granularity, const std::vector<
}
while (sum < nchunk) {
last_split = 0;
-        float best_err = 0;
+        float best_err = -INFINITY;
int ibest = -1;
for (int i = 0; i < (int)splits.size(); ++i) {
float p = splits[i] - last_split;
p += (p - 1.f*mem_used[i]/tot_memory_used);
float n_want = p*nchunk;
float err = n_want - result[i];
if (err > best_err) {
@@ -3034,6 +3040,9 @@ bool create_tensors_helper::create_tensors() {
}
printf("\n");
}
//printf("=== Layer %2d. Mem used so far:", il);
//for (auto mem : mem_used) printf(" %g", mem/1024./1024.);
//printf("\n");
auto & layer = model.layers[il];
auto ctx_split = ctx_for_layer_split(il);
if (layer.attn_norm) {
@@ -3052,8 +3061,12 @@ bool create_tensors_helper::create_tensors() {
if (tt.blck_size > granularity_vo) granularity_vo = tt.blck_size;
GGML_ASSERT(granularity_vo % hparams.n_embd_head_v == 0);
}
-            auto split_vo = create_split(layer.wo->ne[0], granularity_vo, cur_splits, mem_used);
-            auto split_kq = create_split(layer.wq->ne[1], granularity_kq, cur_splits, mem_used);
+            auto split_vo = create_split(layer.wo->ne[0], granularity_vo, cur_splits, mem_used); //, true);
+            auto split_kq = create_split(layer.wq->ne[1], granularity_kq, cur_splits, mem_used); //, true);
//printf(" split_vo:"); for (auto s : split_vo) printf(" %d", s);
//printf("\n");
//printf(" split_kq:"); for (auto s : split_kq) printf(" %d", s);
//printf("\n");
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split_vo, mem_used);
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split_kq, mem_used);
if (layer.bo) {