From 0abb6c4686efbc7521b2bf8a6f06eba18d0b6ada Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 28 Aug 2024 08:07:42 -0700
Subject: [PATCH] Second Attempt for #1502

---
 backend/memory_management.py     | 129 +++++++++++++++++++++++--------
 backend/operations.py            |  13 +---
 backend/operations_bnb.py        |   6 +-
 packages_3rdparty/gguf/quants.py |   2 -
 4 files changed, 102 insertions(+), 48 deletions(-)

diff --git a/backend/memory_management.py b/backend/memory_management.py
index a43268a2..f8b57a40 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -7,7 +7,7 @@
 import torch
 import platform
 from enum import Enum
 
-from backend import stream
+from backend import stream, utils
 from backend.args import args
 
@@ -337,9 +337,12 @@ def state_dict_dtype(state_dict):
     return major_dtype
 
 
-def module_size(module, exclude_device=None):
+def module_size(module, exclude_device=None, return_split=False):
     module_mem = 0
-    for p in module.parameters():
+    weight_mem = 0
+    weight_patterns = ['weight']
+
+    for k, p in module.named_parameters():
         t = p.data
 
         if exclude_device is not None:
@@ -357,9 +360,66 @@
             element_size = 1.1 # a bit more than 0.5 because of quant state parameters
 
         module_mem += t.nelement() * element_size
+
+        if k in weight_patterns:
+            weight_mem += t.nelement() * element_size
+
+    if return_split:
+        return module_mem, weight_mem, module_mem - weight_mem
+
     return module_mem
 
 
+def module_move(module, device, recursive=True, excluded_pattens=[]):
+    if recursive:
+        return module.to(device=device)
+
+    for k, p in module.named_parameters(recurse=False, remove_duplicate=True):
+        if k in excluded_pattens:
+            continue
+        setattr(module, k, utils.tensor2parameter(p.to(device=device)))
+
+    return module
+
+
+def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
+    all_modules = []
+    legacy_modules = []
+
+    for m in model.modules():
+        if hasattr(m, "parameters_manual_cast"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            all_modules.append(m)
+        elif hasattr(m, "weight"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            legacy_modules.append(m)
+
+    gpu_modules = []
+    gpu_modules_only_extras = []
+    mem_counter = 0
+
+    for m in legacy_modules.copy():
+        gpu_modules.append(m)
+        legacy_modules.remove(m)
+        mem_counter += m.total_mem
+
+    for m in sorted(all_modules, key=lambda x: x.extra_mem).copy():
+        if mem_counter + m.extra_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules_only_extras.append(m)
+            all_modules.remove(m)
+            mem_counter += m.extra_mem
+
+    cpu_modules = all_modules
+
+    for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem).copy():
+        if mem_counter + m.weight_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules.append(m)
+            gpu_modules_only_extras.remove(m)
+            mem_counter += m.weight_mem
+
+    return gpu_modules, gpu_modules_only_extras, cpu_modules
+
+
 class LoadedModel:
     def __init__(self, model, memory_required):
         self.model = model
@@ -392,43 +452,50 @@ class LoadedModel:
                 raise e
 
         if not do_not_need_cpu_swap:
-            memory_in_swap = 0
+            gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
+            pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
+
             mem_counter = 0
-            mem_cannot_cast = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "parameters_manual_cast"):
-                    m.prev_parameters_manual_cast = m.parameters_manual_cast
-                    m.parameters_manual_cast = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < model_gpu_memory_when_using_cpu_swap:
-                        m.to(self.device)
-                        mem_counter += module_mem
+            swap_counter = 0
+
+            for m in gpu_modules:
+                m.to(self.device)
+                mem_counter += m.total_mem
+
+            for m in cpu_modules + gpu_modules_only_extras:
+                if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
+                    m.to(self.device) # Quantize happens here
+
+            for m in cpu_modules:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                m.to(self.model.offload_device)
+                if pin_memory:
+                    m._apply(lambda x: x.pin_memory())
+                swap_counter += m.total_mem
+
+            for m in gpu_modules_only_extras:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
+                if hasattr(m, 'weight') and m.weight is not None:
+                    if pin_memory:
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device).pin_memory())
                     else:
-                        memory_in_swap += module_mem
-
-                        if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
-                            m.to(self.device) # Quantize happens here
-
-                        m.to(self.model.offload_device)
-
-                        if PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device):
-                            m._apply(lambda x: x.pin_memory())
-                elif hasattr(m, "weight"):
-                    m.to(self.device)
-                    module_mem = module_size(m)
-                    mem_counter += module_mem
-                    mem_cannot_cast += module_mem
-
-            if mem_cannot_cast > 0:
-                print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device))
+                mem_counter += m.extra_mem
+                swap_counter += m.weight_mem
 
             swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
             method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
-            print(f"[Memory Management] Loaded to {swap_flag} Swap: {memory_in_swap / (1024 * 1024):.2f} MB ({method_flag} method)")
+            print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
             print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
 
         self.model_accelerated = True
 
+        global signal_empty_cache
+        signal_empty_cache = True
+
         if is_intel_xpu() and not args.disable_ipex_hijack:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
diff --git a/backend/operations.py b/backend/operations.py
index b3d34b82..9666d40d 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -412,17 +412,8 @@ class ForgeOperationsGGUF(ForgeOperations):
             return
 
     def _apply(self, fn, recurse=True):
-        if self.weight is not None:
-            self.weight = utils.tensor2parameter(fn(self.weight))
-        if self.bias is not None:
-            self.bias = utils.tensor2parameter(fn(self.bias))
-        for i in range(5):
-            quant_state_name = f'quant_state_{i}'
-            quant_state = getattr(self, quant_state_name, None)
-            if quant_state is not None:
-                quant_state = fn(quant_state)
-                quant_state = utils.tensor2parameter(quant_state)
-                setattr(self, quant_state_name, quant_state)
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
         return self
 
     def forward(self, x):
diff --git a/backend/operations_bnb.py b/backend/operations_bnb.py
index eca619aa..968f62a6 100644
--- a/backend/operations_bnb.py
+++ b/backend/operations_bnb.py
@@ -92,10 +92,8 @@ class ForgeLoader4Bit(torch.nn.Module):
         self.quant_type = quant_type
 
     def _apply(self, fn, recurse=True):
-        if self.weight is not None:
-            self.weight = utils.tensor2parameter(fn(self.weight))
-        if self.bias is not None:
-            self.bias = utils.tensor2parameter(fn(self.bias))
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
         return self
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py
index 4e013141..c0d144d5 100644
--- a/packages_3rdparty/gguf/quants.py
+++ b/packages_3rdparty/gguf/quants.py
@@ -620,8 +620,6 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         if d.device != x.device:
             d = d.to(device=x.device)
 
-        x = x.to(cls.computation_dtype)
-
         return x * d
 
     @classmethod
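
The heart of the patch is the packing policy in build_module_profile: legacy modules that cannot be manually cast always go to the GPU, and the remaining budget is then spent in two passes, first on the small non-weight tensors of every castable module (sorted by ascending extra_mem), and only afterwards on whole weights (sorted by ascending weight_mem). The following is a minimal, self-contained sketch of that policy, not Forge code: FakeModule, profile() and the byte figures are made up for illustration, and the legacy-module pass is omitted.

    from dataclasses import dataclass


    @dataclass
    class FakeModule:
        # Hypothetical stand-in for an nn.Module that module_size() has already profiled.
        name: str
        weight_mem: int   # bytes held by the 'weight' tensor
        extra_mem: int    # bytes held by everything else (bias, quant states, ...)


    def profile(modules, gpu_budget):
        """Simplified version of the two-pass policy in build_module_profile."""
        gpu_modules = []              # fully resident on the GPU
        gpu_modules_only_extras = []  # extras on the GPU, weight swapped in from CPU
        remaining = list(modules)
        mem = 0

        # Pass 1: cheap non-weight tensors first, smallest extras packed first.
        for m in sorted(remaining, key=lambda x: x.extra_mem):
            if mem + m.extra_mem < gpu_budget:
                gpu_modules_only_extras.append(m)
                remaining.remove(m)
                mem += m.extra_mem

        cpu_modules = remaining  # nothing of these fits; they stay on the offload device

        # Pass 2: spend the leftover budget promoting whole weights, smallest first.
        for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem):
            if mem + m.weight_mem < gpu_budget:
                gpu_modules.append(m)
                gpu_modules_only_extras.remove(m)
                mem += m.weight_mem

        return gpu_modules, gpu_modules_only_extras, cpu_modules, mem


    if __name__ == '__main__':
        mods = [
            FakeModule('attn_qkv', weight_mem=300, extra_mem=10),
            FakeModule('attn_out', weight_mem=200, extra_mem=5),
            FakeModule('mlp_in', weight_mem=400, extra_mem=20),
            FakeModule('mlp_out', weight_mem=350, extra_mem=15),
        ]
        gpu, extras_only, cpu, used = profile(mods, gpu_budget=600)
        print('fully on GPU:   ', [m.name for m in gpu])          # attn_out, attn_qkv
        print('extras on GPU:  ', [m.name for m in extras_only])  # mlp_out, mlp_in
        print('fully on CPU:   ', [m.name for m in cpu])          # (empty here)
        print('planned GPU use:', used)                           # 550 of 600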
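
The _apply rewrites in operations.py and operations_bnb.py replace the hand-listed weight/bias/quant_state_{i} handling with a single loop over named_parameters(recurse=False, remove_duplicate=True), so any directly registered parameter is converted regardless of its name. A rough sketch of the same pattern on a toy, inference-only module is shown below; TinyQuantLinear is hypothetical, and plain nn.Parameter stands in for backend.utils.tensor2parameter.

    import torch
    import torch.nn as nn


    class TinyQuantLinear(nn.Module):
        # Hypothetical module: besides weight and bias it registers extra state
        # as parameters, the way ForgeLoader4Bit keeps its quant_state_{i} tensors.
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.zeros(4, 4), requires_grad=False)
            self.bias = nn.Parameter(torch.zeros(4), requires_grad=False)
            self.quant_state_0 = nn.Parameter(torch.zeros(2), requires_grad=False)

        def _apply(self, fn, recurse=True):
            # Same shape as the patched _apply: convert every direct parameter,
            # whatever it is called, instead of naming weight/bias explicitly.
            for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
                setattr(self, k, nn.Parameter(fn(p), requires_grad=False))
            return self


    m = TinyQuantLinear().to(torch.float16)
    print({k: v.dtype for k, v in m.named_parameters()})
    # weight, bias and quant_state_0 all report torch.float16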