Second Attempt for #1502

layerdiffusion
2024-08-28 08:07:42 -07:00
parent 37fcb7bce8
commit 0abb6c4686
4 changed files with 102 additions and 48 deletions

View File

@@ -7,7 +7,7 @@ import torch
 import platform
 from enum import Enum
-from backend import stream
+from backend import stream, utils
 from backend.args import args
@@ -337,9 +337,12 @@ def state_dict_dtype(state_dict):
     return major_dtype
 
 
-def module_size(module, exclude_device=None):
+def module_size(module, exclude_device=None, return_split=False):
     module_mem = 0
-    for p in module.parameters():
+    weight_mem = 0
+    weight_patterns = ['weight']
+
+    for k, p in module.named_parameters():
         t = p.data
         if exclude_device is not None:
@@ -357,9 +360,66 @@ def module_size(module, exclude_device=None):
             element_size = 1.1  # a bit more than 0.5 because of quant state parameters
         module_mem += t.nelement() * element_size
+        if k in weight_patterns:
+            weight_mem += t.nelement() * element_size
+
+    if return_split:
+        return module_mem, weight_mem, module_mem - weight_mem
+
     return module_mem
 
 
+def module_move(module, device, recursive=True, excluded_pattens=[]):
+    if recursive:
+        return module.to(device=device)
+
+    for k, p in module.named_parameters(recurse=False, remove_duplicate=True):
+        if k in excluded_pattens:
+            continue
+        setattr(module, k, utils.tensor2parameter(p.to(device=device)))
+
+    return module
+
+
+def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
+    all_modules = []
+    legacy_modules = []
+
+    for m in model.modules():
+        if hasattr(m, "parameters_manual_cast"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            all_modules.append(m)
+        elif hasattr(m, "weight"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            legacy_modules.append(m)
+
+    gpu_modules = []
+    gpu_modules_only_extras = []
+    mem_counter = 0
+
+    for m in legacy_modules.copy():
+        gpu_modules.append(m)
+        legacy_modules.remove(m)
+        mem_counter += m.total_mem
+
+    for m in sorted(all_modules, key=lambda x: x.extra_mem).copy():
+        if mem_counter + m.extra_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules_only_extras.append(m)
+            all_modules.remove(m)
+            mem_counter += m.extra_mem
+
+    cpu_modules = all_modules
+
+    for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem).copy():
+        if mem_counter + m.weight_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules.append(m)
+            gpu_modules_only_extras.remove(m)
+            mem_counter += m.weight_mem
+
+    return gpu_modules, gpu_modules_only_extras, cpu_modules
+
+
 class LoadedModel:
     def __init__(self, model, memory_required):
         self.model = model
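Note: build_module_profile splits the model into three buckets: gpu_modules (legacy modules without parameters_manual_cast support, which must live fully on GPU), gpu_modules_only_extras (modules whose small non-weight tensors are kept on GPU while the large weight stays swappable), and cpu_modules (fully swapped). The fill is greedy: legacy modules are charged first, then the cheap "extras" are admitted smallest-first, and any remaining budget promotes whole modules smallest-weight-first. Below is a minimal standalone sketch of the same greedy policy with invented module names and sizes in MB instead of real torch modules (and no legacy bucket); it is an illustration, not code from the commit.

from types import SimpleNamespace

def build_profile(modules, budget):
    gpu, gpu_only_extras, used = [], [], 0
    # Pass 1: admit the cheap non-weight "extras" first, smallest first.
    for m in sorted(modules, key=lambda x: x.extra_mem):
        if used + m.extra_mem < budget:
            gpu_only_extras.append(m)
            used += m.extra_mem
    cpu = [m for m in modules if m not in gpu_only_extras]
    # Pass 2: with leftover budget, promote whole modules (weight included).
    for m in sorted(list(gpu_only_extras), key=lambda x: x.weight_mem):
        if used + m.weight_mem < budget:
            gpu.append(m)
            gpu_only_extras.remove(m)
            used += m.weight_mem
    return gpu, gpu_only_extras, cpu

mods = [SimpleNamespace(name=n, weight_mem=w, extra_mem=e)
        for n, w, e in [('attn', 300, 4), ('mlp', 600, 8), ('norm', 1, 1)]]
gpu, extras_only, cpu = build_profile(mods, budget=320)
print([m.name for m in gpu], [m.name for m in extras_only], [m.name for m in cpu])

With a 320 MB budget, the tiny norm and the 300 MB attn are fully promoted, while mlp keeps only its extras on GPU and leaves its weight swappable.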
@@ -392,43 +452,50 @@ class LoadedModel:
            raise e
 
        if not do_not_need_cpu_swap:
-            memory_in_swap = 0
+            gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
+            pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
+
             mem_counter = 0
-            mem_cannot_cast = 0
-
-            for m in self.real_model.modules():
-                if hasattr(m, "parameters_manual_cast"):
-                    m.prev_parameters_manual_cast = m.parameters_manual_cast
-                    m.parameters_manual_cast = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < model_gpu_memory_when_using_cpu_swap:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                    else:
-                        memory_in_swap += module_mem
-                        if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
-                            m.to(self.device)  # Quantize happens here
-                        m.to(self.model.offload_device)
-                        if PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device):
-                            m._apply(lambda x: x.pin_memory())
-                elif hasattr(m, "weight"):
-                    m.to(self.device)
-                    module_mem = module_size(m)
-                    mem_counter += module_mem
-                    mem_cannot_cast += module_mem
-
-            if mem_cannot_cast > 0:
-                print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
+            swap_counter = 0
+
+            for m in gpu_modules:
+                m.to(self.device)
+                mem_counter += m.total_mem
+
+            for m in cpu_modules + gpu_modules_only_extras:
+                if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
+                    m.to(self.device)  # Quantize happens here
+
+            for m in cpu_modules:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                m.to(self.model.offload_device)
+                if pin_memory:
+                    m._apply(lambda x: x.pin_memory())
+                swap_counter += m.total_mem
+
+            for m in gpu_modules_only_extras:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
+                if hasattr(m, 'weight') and m.weight is not None:
+                    if pin_memory:
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device).pin_memory())
+                    else:
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device))
+                mem_counter += m.extra_mem
+                swap_counter += m.weight_mem
 
             swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
             method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
-            print(f"[Memory Management] Loaded to {swap_flag} Swap: {memory_in_swap / (1024 * 1024):.2f} MB ({method_flag} method)")
+            print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
             print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
 
             self.model_accelerated = True
 
+            global signal_empty_cache
+            signal_empty_cache = True
+
         if is_intel_xpu() and not args.disable_ipex_hijack:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
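Note: in the new load path, swapped modules keep their "extras" (bias, norm parameters, quant state) resident on GPU via module_move(..., excluded_pattens=['weight']), while only the large weight tensor stays on the offload device, pinned when PIN_SHARED_MEMORY applies. Pinning matters because only page-locked host memory supports truly asynchronous host-to-GPU copies, which is presumably what the 'asynchronous' method_flag path relies on. A hedged sketch of that mechanism in generic PyTorch (not the commit's actual backend.stream code; tensor shape and names are invented):

import torch

if torch.cuda.is_available():
    w_cpu = torch.randn(4096, 4096).pin_memory()          # swapped-out weight, page-locked
    copy_stream = torch.cuda.Stream()                     # side stream for transfers
    with torch.cuda.stream(copy_stream):
        w_gpu = w_cpu.to('cuda', non_blocking=True)       # async H2D copy, can overlap compute
    torch.cuda.current_stream().wait_stream(copy_stream)  # sync before the weight is used
    y = w_gpu.sum()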

View File

@@ -412,17 +412,8 @@ class ForgeOperationsGGUF(ForgeOperations):
             return
 
     def _apply(self, fn, recurse=True):
-        if self.weight is not None:
-            self.weight = utils.tensor2parameter(fn(self.weight))
-        if self.bias is not None:
-            self.bias = utils.tensor2parameter(fn(self.bias))
-        for i in range(5):
-            quant_state_name = f'quant_state_{i}'
-            quant_state = getattr(self, quant_state_name, None)
-            if quant_state is not None:
-                quant_state = fn(quant_state)
-                quant_state = utils.tensor2parameter(quant_state)
-                setattr(self, quant_state_name, quant_state)
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
         return self
 
     def forward(self, x):
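Note: the old _apply special-cased weight, bias, and exactly five quant_state_{i} slots; the rewrite walks named_parameters(recurse=False, remove_duplicate=True) so every parameter registered on the module is transformed, and newly added slots need no extra code. The same rewrite lands in ForgeLoader4Bit in the next file. A minimal sketch of the pattern; Demo is an invented module and tensor2parameter here is a stand-in for backend.utils.tensor2parameter:

import torch
import torch.nn as nn

def tensor2parameter(t):
    # Stand-in: wrap a plain tensor back into nn.Parameter after fn transforms it.
    return t if isinstance(t, nn.Parameter) else nn.Parameter(t, requires_grad=False)

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        self.quant_state_0 = nn.Parameter(torch.zeros(3))  # extra slot, no special case needed

    def _apply(self, fn, recurse=True):
        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
            setattr(self, k, tensor2parameter(fn(p)))
        return self

m = Demo().to(torch.float16)                  # Module.to() routes through _apply
print(m.weight.dtype, m.quant_state_0.dtype)  # torch.float16 torch.float16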

View File

@@ -92,10 +92,8 @@ class ForgeLoader4Bit(torch.nn.Module):
         self.quant_type = quant_type
 
     def _apply(self, fn, recurse=True):
-        if self.weight is not None:
-            self.weight = utils.tensor2parameter(fn(self.weight))
-        if self.bias is not None:
-            self.bias = utils.tensor2parameter(fn(self.bias))
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
         return self
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):

View File

@@ -620,8 +620,6 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         if d.device != x.device:
             d = d.to(device=x.device)
 
-        x = x.to(cls.computation_dtype)
-
         return x * d
 
     @classmethod
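Note: with the explicit x.to(cls.computation_dtype) removed, the dequantize result's dtype now follows the scale d through PyTorch type promotion (an integer tensor times a floating tensor yields the floating dtype), so the int8 payload never needs a separate full-size cast; the result dtype is d's rather than cls.computation_dtype. A small check of that promotion rule, with invented values:

import torch

x = torch.randint(-128, 127, (32,), dtype=torch.int8)  # stand-in for a Q8_0 quantized block
d = torch.tensor(0.01, dtype=torch.float16)            # stand-in for the per-block scale
y = x * d                                              # promotes without an explicit cast
print(y.dtype)  # torch.float16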