Second Attempt for #1502

layerdiffusion
2024-08-28 08:07:42 -07:00
parent 37fcb7bce8
commit 0abb6c4686
4 changed files with 102 additions and 48 deletions


@@ -7,7 +7,7 @@ import torch
 import platform
 from enum import Enum
-from backend import stream
+from backend import stream, utils
 from backend.args import args
@@ -337,9 +337,12 @@ def state_dict_dtype(state_dict):
     return major_dtype


-def module_size(module, exclude_device=None):
+def module_size(module, exclude_device=None, return_split=False):
     module_mem = 0
-    for p in module.parameters():
+    weight_mem = 0
+    weight_patterns = ['weight']
+
+    for k, p in module.named_parameters():
         t = p.data
         if exclude_device is not None:
@@ -357,9 +360,66 @@ def module_size(module, exclude_device=None):
             element_size = 1.1  # a bit more than 0.5 because of quant state parameters
         module_mem += t.nelement() * element_size
+
+        if k in weight_patterns:
+            weight_mem += t.nelement() * element_size
+
+    if return_split:
+        return module_mem, weight_mem, module_mem - weight_mem
+
     return module_mem
+
+
+def module_move(module, device, recursive=True, excluded_pattens=[]):
+    if recursive:
+        return module.to(device=device)
+
+    for k, p in module.named_parameters(recurse=False, remove_duplicate=True):
+        if k in excluded_pattens:
+            continue
+        setattr(module, k, utils.tensor2parameter(p.to(device=device)))
+
+    return module
+
+
+def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
+    all_modules = []
+    legacy_modules = []
+
+    for m in model.modules():
+        if hasattr(m, "parameters_manual_cast"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            all_modules.append(m)
+        elif hasattr(m, "weight"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            legacy_modules.append(m)
+
+    gpu_modules = []
+    gpu_modules_only_extras = []
+    mem_counter = 0
+
+    for m in legacy_modules.copy():
+        gpu_modules.append(m)
+        legacy_modules.remove(m)
+        mem_counter += m.total_mem
+
+    for m in sorted(all_modules, key=lambda x: x.extra_mem).copy():
+        if mem_counter + m.extra_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules_only_extras.append(m)
+            all_modules.remove(m)
+            mem_counter += m.extra_mem
+
+    cpu_modules = all_modules
+
+    for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem).copy():
+        if mem_counter + m.weight_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules.append(m)
+            gpu_modules_only_extras.remove(m)
+            mem_counter += m.weight_mem
+
+    return gpu_modules, gpu_modules_only_extras, cpu_modules
+
+
 class LoadedModel:
     def __init__(self, model, memory_required):
         self.model = model
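
Taken together, module_size(return_split=True) separates each module's cost into its 'weight' tensor versus everything else (biases, norms, quant states), and build_module_profile uses that split to plan three tiers of residency: gpu_modules live fully on the GPU, gpu_modules_only_extras keep only their small non-weight tensors resident while the weight is swapped, and cpu_modules stay on the offload device. Legacy modules that cannot be manually cast are admitted first; the remaining budget is then filled greedily, cheapest extras first, then smallest weights. A minimal standalone sketch of that greedy budgeting, with made-up module sizes and hypothetical names rather than the Forge API:

# Greedy two-pass residency planning, mirroring build_module_profile's idea.
# modules: (name, weight_mem, extra_mem) tuples; budget: GPU bytes available in swap mode.
def plan(modules, budget):
    gpu, extras_only, used = [], [], 0

    # Pass 1: admit the cheap non-weight tensors first, smallest extras first.
    for name, weight_mem, extra_mem in sorted(modules, key=lambda m: m[2]):
        if used + extra_mem < budget:
            extras_only.append(name)
            used += extra_mem

    # Pass 2: upgrade admitted modules to full residency, smallest weights first.
    for name, weight_mem, extra_mem in sorted(modules, key=lambda m: m[1]):
        if name in extras_only and used + weight_mem < budget:
            gpu.append(name)
            extras_only.remove(name)
            used += weight_mem

    cpu = [name for name, _, _ in modules if name not in gpu and name not in extras_only]
    return gpu, extras_only, cpu


print(plan([('attn', 50, 1), ('ff', 100, 2), ('big', 400, 80)], budget=65))
# -> (['attn'], ['ff'], ['big'])

Filling the extras first maximizes how many modules keep their bias, normalization, and quant-state tensors resident, so that at inference time only the weight tensor has to be streamed in for the swapped modules.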
@@ -392,43 +452,50 @@ class LoadedModel:
             raise e

         if not do_not_need_cpu_swap:
-            memory_in_swap = 0
+            gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
+            pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
+
             mem_counter = 0
-            mem_cannot_cast = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "parameters_manual_cast"):
-                    m.prev_parameters_manual_cast = m.parameters_manual_cast
-                    m.parameters_manual_cast = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < model_gpu_memory_when_using_cpu_swap:
-                        m.to(self.device)
-                        mem_counter += module_mem
+            swap_counter = 0
+
+            for m in gpu_modules:
+                m.to(self.device)
+                mem_counter += m.total_mem
+
+            for m in cpu_modules + gpu_modules_only_extras:
+                if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
+                    m.to(self.device)  # Quantize happens here
+
+            for m in cpu_modules:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                m.to(self.model.offload_device)
+                if pin_memory:
+                    m._apply(lambda x: x.pin_memory())
+                swap_counter += m.total_mem
+
+            for m in gpu_modules_only_extras:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
+                if hasattr(m, 'weight') and m.weight is not None:
+                    if pin_memory:
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device).pin_memory())
                     else:
-                        memory_in_swap += module_mem
-                        if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
-                            m.to(self.device)  # Quantize happens here
-                        m.to(self.model.offload_device)
-                        if PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device):
-                            m._apply(lambda x: x.pin_memory())
-                elif hasattr(m, "weight"):
-                    m.to(self.device)
-                    module_mem = module_size(m)
-                    mem_counter += module_mem
-                    mem_cannot_cast += module_mem
-            if mem_cannot_cast > 0:
-                print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device))
+                mem_counter += m.extra_mem
+                swap_counter += m.weight_mem

             swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
             method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
-            print(f"[Memory Management] Loaded to {swap_flag} Swap: {memory_in_swap / (1024 * 1024):.2f} MB ({method_flag} method)")
+            print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
             print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")

         self.model_accelerated = True

         global signal_empty_cache
         signal_empty_cache = True

         if is_intel_xpu() and not args.disable_ipex_hijack:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
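
The rewritten loading pass then executes that plan: gpu_modules move over whole; bitsandbytes weights in the other two tiers take a one-time round trip through CUDA so quantization happens before they are parked; cpu_modules are placed (and optionally pinned) on the offload device; and gpu_modules_only_extras are moved to the GPU except for their weight, which stays behind on pinned host memory so it can be streamed in later. A rough sketch of why the pinning matters for the asynchronous method, using plain PyTorch streams rather than the project's backend.stream helper:

import torch

# Page-locked (pinned) host memory allows a truly asynchronous H2D copy,
# so the weight transfer can overlap with compute on a separate CUDA stream.
def stream_in(weight_cpu: torch.Tensor, copy_stream: torch.cuda.Stream) -> torch.Tensor:
    with torch.cuda.stream(copy_stream):
        weight_gpu = weight_cpu.to('cuda', non_blocking=True)  # async because weight_cpu is pinned
    # Make later work on the current stream wait for the copy before using the weight.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return weight_gpu


if torch.cuda.is_available():
    w = torch.randn(4096, 4096).pin_memory()
    w_gpu = stream_in(w, torch.cuda.Stream())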


@@ -412,17 +412,8 @@ class ForgeOperationsGGUF(ForgeOperations):
             return

         def _apply(self, fn, recurse=True):
-            if self.weight is not None:
-                self.weight = utils.tensor2parameter(fn(self.weight))
-            if self.bias is not None:
-                self.bias = utils.tensor2parameter(fn(self.bias))
-            for i in range(5):
-                quant_state_name = f'quant_state_{i}'
-                quant_state = getattr(self, quant_state_name, None)
-                if quant_state is not None:
-                    quant_state = fn(quant_state)
-                    quant_state = utils.tensor2parameter(quant_state)
-                    setattr(self, quant_state_name, quant_state)
+            for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+                setattr(self, k, utils.tensor2parameter(fn(p)))
             return self

         def forward(self, x):


@@ -92,10 +92,8 @@ class ForgeLoader4Bit(torch.nn.Module):
         self.quant_type = quant_type

     def _apply(self, fn, recurse=True):
-        if self.weight is not None:
-            self.weight = utils.tensor2parameter(fn(self.weight))
-        if self.bias is not None:
-            self.bias = utils.tensor2parameter(fn(self.bias))
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
         return self

     def _save_to_state_dict(self, destination, prefix, keep_vars):
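
The two _apply rewrites above (the GGUF layers and ForgeLoader4Bit) replace the hand-maintained list of attributes, weight, bias, and a fixed range of quant_state_{i} slots, with a single loop over the module's directly registered parameters, so any parameter added later is mapped through fn automatically. A small illustration of what named_parameters(recurse=False, remove_duplicate=True) picks up, on a made-up module rather than the actual Forge layer:

import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    def __init__(self):
        super().__init__()
        # Integer weights are fine as Parameters as long as requires_grad=False.
        self.weight = nn.Parameter(torch.zeros(4, 4, dtype=torch.uint8), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(4))
        self.quant_state_0 = nn.Parameter(torch.ones(1), requires_grad=False)


m = FakeQuantLinear()
print([k for k, _ in m.named_parameters(recurse=False, remove_duplicate=True)])
# -> ['weight', 'bias', 'quant_state_0']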


@@ -620,8 +620,6 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
-        if d.device != x.device:
-            d = d.to(device=x.device)
         x = x.to(cls.computation_dtype)
         return x * d

     @classmethod
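
Finally, the Q8_0 hunk drops the explicit device fix-up for the scale tensor d before the multiply, presumably because, with the _apply changes above, the quantized values and their scales now travel together and already share a device. For reference, Q8_0 stores blocks of 32 int8 values with one scale per block, and dequantization is the cast-and-multiply kept in the remaining lines. A toy sketch of that block layout, illustrative only and not the exact GGUF byte format:

import torch

def dequant_q8_0(q: torch.Tensor, d: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    # q: (n_blocks, 32) int8 quantized values; d: (n_blocks, 1) per-block scales.
    x = q.to(dtype)
    return x * d


q = torch.randint(-128, 128, (4, 32), dtype=torch.int8)
d = torch.rand(4, 1, dtype=torch.float16)
print(dequant_q8_0(q, d).shape)  # torch.Size([4, 32])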