mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Fix unhandled OoM condition when loading GPTQ model with auto split
Free minimum reserved VRAM on previous device when moving to next device
This commit is contained in:
@@ -72,6 +72,8 @@ QMatrix::QMatrix
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
failed = false;
|
||||
|
||||
cuda_q_weight = _q_weight;
|
||||
cuda_q_perm = _q_perm;
|
||||
cuda_q_invperm = _q_invperm;
|
||||
@@ -125,7 +127,14 @@ QMatrix::QMatrix
|
||||
rows_3 = height;
|
||||
rows_2 = height;
|
||||
|
||||
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
|
||||
if (_gptq_g_idx)
|
||||
{
|
||||
if (!make_sequential(_gptq_g_idx))
|
||||
{
|
||||
failed = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shuffle quantized data
|
||||
@@ -527,10 +536,11 @@ __global__ void make_sequential_kernel
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
if (err != cudaSuccess) return false;
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
@@ -604,4 +614,6 @@ void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ public:
|
||||
|
||||
half* temp_dq;
|
||||
|
||||
bool failed;
|
||||
|
||||
QMatrix
|
||||
(
|
||||
const int _device,
|
||||
@@ -62,7 +64,7 @@ public:
|
||||
~QMatrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
void make_sequential(const uint32_t* cpu_g_idx);
|
||||
bool make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
private:
|
||||
|
||||
|
||||
@@ -233,6 +233,8 @@ uintptr_t make_q_matrix
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
|
||||
if (m->failed) throw std::runtime_error("CUDA out of memory");
|
||||
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
|
||||
@@ -277,7 +277,7 @@ class ExLlamaV2:
|
||||
|
||||
assert not self.config.qkv_embed, "Auto GPU split is unsupported when config.qkv_embed = True"
|
||||
|
||||
minimum_reserve_vram = 16 * 1024**2
|
||||
minimum_reserve_vram = 32 * 1024**2
|
||||
last_touched_device = -1
|
||||
current_device = 0
|
||||
num_devices = torch.torch.cuda.device_count()
|
||||
@@ -293,6 +293,7 @@ class ExLlamaV2:
|
||||
reserve_vram = [32 * 1024**2] + [0] * (num_devices - 1)
|
||||
|
||||
reserved_vram_tensors = []
|
||||
minimum_reserve_tensor = None
|
||||
|
||||
# Largest hidden state to ever forward through model
|
||||
|
||||
@@ -334,8 +335,9 @@ class ExLlamaV2:
|
||||
if attn_mask is not None: reserved_vram_tensors.append(attn_mask)
|
||||
attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, None, _torch_device(current_device))
|
||||
|
||||
b = reserve_vram[current_device] + minimum_reserve_vram
|
||||
b = reserve_vram[current_device]
|
||||
reserved_vram_tensors.append(torch.empty((b,), dtype = torch.int8, device = _torch_device(current_device)))
|
||||
minimum_reserve = torch.empty((minimum_reserve_vram,), dtype = torch.int8, device = _torch_device(current_device))
|
||||
|
||||
last_touched_device = current_device
|
||||
|
||||
@@ -375,6 +377,10 @@ class ExLlamaV2:
|
||||
|
||||
module.unload()
|
||||
hidden_state = None
|
||||
|
||||
if minimum_reserve_tensor is not None: del minimum_reserve_tensor
|
||||
minimum_reserve_tensor = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
hidden_state = hidden_state_backup.clone()
|
||||
@@ -469,7 +475,7 @@ class ExLlamaV2:
|
||||
|
||||
for i in range(len(past_len[1])):
|
||||
|
||||
attn_mask = torch.zeros(1, 1, seq_len, past_len[1][i] + seq_len, dtype = torch.float16, device = device)
|
||||
attn_mask = torch.zeros((1, 1, seq_len, past_len[1][i] + seq_len), dtype = torch.float16, device = device)
|
||||
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.))
|
||||
attn_mask[:, :, : seq_len - 1, past_len[1][i] + 1: past_len[1][i] + seq_len] = attn_mask_triu
|
||||
|
||||
@@ -485,7 +491,7 @@ class ExLlamaV2:
|
||||
|
||||
else:
|
||||
|
||||
attn_mask = torch.zeros(batch_size, 1, seq_len, past_len + seq_len, dtype = torch.float16, device = device)
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)
|
||||
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.))
|
||||
attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import time
|
||||
|
||||
# Initialize model and cache
|
||||
|
||||
model_directory = "/mnt/str/models/_exl2/llama2-70b-exl2/2.5bpw/"
|
||||
model_directory = "/mnt/str/models/_gptq/TheBloke_Spicyboros-70B-2.2-GPTQ/"
|
||||
|
||||
config = ExLlamaV2Config()
|
||||
config.model_dir = model_directory
|
||||
@@ -34,7 +34,7 @@ print("Loading model: " + model_directory)
|
||||
def progress_rep(module, num_modules):
|
||||
yield f"Progress: {100 * module / num_modules:.2f}%"
|
||||
|
||||
cache = ExLlamaV2Cache_8bit(model, lazy = True)
|
||||
cache = ExLlamaV2Cache(model, lazy = True)
|
||||
|
||||
f = model.load_autosplit_gen(cache, last_id_only = True, callback_gen = progress_rep)
|
||||
for item in f:
|
||||
|
||||
Reference in New Issue
Block a user