Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-01 22:09:46 +00:00)
Second Attempt for #1502
@@ -7,7 +7,7 @@ import torch
 import platform
 
 from enum import Enum
 
-from backend import stream
+from backend import stream, utils
 from backend.args import args
 
@@ -337,9 +337,12 @@ def state_dict_dtype(state_dict):
     return major_dtype
 
 
-def module_size(module, exclude_device=None):
+def module_size(module, exclude_device=None, return_split=False):
     module_mem = 0
-    for p in module.parameters():
+    weight_mem = 0
+    weight_patterns = ['weight']
+
+    for k, p in module.named_parameters():
         t = p.data
 
         if exclude_device is not None:
@@ -357,9 +360,66 @@ def module_size(module, exclude_device=None):
             element_size = 1.1 # a bit more than 0.5 because of quant state parameters
 
         module_mem += t.nelement() * element_size
 
+        if k in weight_patterns:
+            weight_mem += t.nelement() * element_size
+
+    if return_split:
+        return module_mem, weight_mem, module_mem - weight_mem
+
     return module_mem
 
 
+def module_move(module, device, recursive=True, excluded_pattens=[]):
+    if recursive:
+        return module.to(device=device)
+
+    for k, p in module.named_parameters(recurse=False, remove_duplicate=True):
+        if k in excluded_pattens:
+            continue
+        setattr(module, k, utils.tensor2parameter(p.to(device=device)))
+
+    return module
+
+
+def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
+    all_modules = []
+    legacy_modules = []
+
+    for m in model.modules():
+        if hasattr(m, "parameters_manual_cast"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            all_modules.append(m)
+        elif hasattr(m, "weight"):
+            m.total_mem, m.weight_mem, m.extra_mem = module_size(m, return_split=True)
+            legacy_modules.append(m)
+
+    gpu_modules = []
+    gpu_modules_only_extras = []
+    mem_counter = 0
+
+    for m in legacy_modules.copy():
+        gpu_modules.append(m)
+        legacy_modules.remove(m)
+        mem_counter += m.total_mem
+
+    for m in sorted(all_modules, key=lambda x: x.extra_mem).copy():
+        if mem_counter + m.extra_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules_only_extras.append(m)
+            all_modules.remove(m)
+            mem_counter += m.extra_mem
+
+    cpu_modules = all_modules
+
+    for m in sorted(gpu_modules_only_extras, key=lambda x: x.weight_mem).copy():
+        if mem_counter + m.weight_mem < model_gpu_memory_when_using_cpu_swap:
+            gpu_modules.append(m)
+            gpu_modules_only_extras.remove(m)
+            mem_counter += m.weight_mem
+
+    return gpu_modules, gpu_modules_only_extras, cpu_modules
+
+
 class LoadedModel:
     def __init__(self, model, memory_required):
         self.model = model
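For illustration, build_module_profile splits modules into three buckets using the weight/extra split reported by module_size(..., return_split=True). Below is a minimal standalone sketch (not Forge code) of the same bucketing idea with plain tuples instead of real modules; toy_profile, the (name, weight_mem, extra_mem) tuples, and the 80 MB budget are invented here purely for illustration.

# Standalone sketch (not Forge code): the bucketing idea of build_module_profile,
# with plain (name, weight_mem, extra_mem) tuples instead of real modules.
# "extra" = everything that is not the big weight tensor (bias, norm scales,
# quant state); these are cheap to keep on the GPU even when the weight is swapped.
def toy_profile(modules, budget):
    gpu, gpu_only_extras, counter = [], [], 0

    # first pass: keep the small "extra" tensors on the GPU while they fit
    for m in sorted(modules, key=lambda m: m[2]):
        if counter + m[2] < budget:
            gpu_only_extras.append(m)
            counter += m[2]

    cpu = [m for m in modules if m not in gpu_only_extras]

    # second pass: promote whole modules whose weight also still fits
    for m in sorted(gpu_only_extras, key=lambda m: m[1]):
        if counter + m[1] < budget:
            gpu.append(m)
            gpu_only_extras.remove(m)
            counter += m[1]

    return gpu, gpu_only_extras, cpu

mods = [("attn", 64, 2), ("mlp", 128, 4), ("norm", 1, 1)]
print(toy_profile(mods, budget=80))  # "mlp" keeps only its extras on the GPU; its weight stays in swap

In the real function, the gpu_modules bucket is additionally seeded with "legacy" modules (those that have a weight but no parameters_manual_cast attribute), which are always moved to the GPU in full.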
@@ -392,43 +452,50 @@ class LoadedModel:
             raise e
 
         if not do_not_need_cpu_swap:
-            memory_in_swap = 0
+            gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
+            pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
+
             mem_counter = 0
-            mem_cannot_cast = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "parameters_manual_cast"):
-                    m.prev_parameters_manual_cast = m.parameters_manual_cast
-                    m.parameters_manual_cast = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < model_gpu_memory_when_using_cpu_swap:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                    else:
-                        memory_in_swap += module_mem
-
-                        if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
-                            m.to(self.device) # Quantize happens here
-
-                        m.to(self.model.offload_device)
-
-                        if PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device):
-                            m._apply(lambda x: x.pin_memory())
-                elif hasattr(m, "weight"):
-                    m.to(self.device)
-                    module_mem = module_size(m)
-                    mem_counter += module_mem
-                    mem_cannot_cast += module_mem
-
-            if mem_cannot_cast > 0:
-                print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
+            swap_counter = 0
+
+            for m in gpu_modules:
+                m.to(self.device)
+                mem_counter += m.total_mem
+
+            for m in cpu_modules + gpu_modules_only_extras:
+                if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
+                    m.to(self.device) # Quantize happens here
+
+            for m in cpu_modules:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                m.to(self.model.offload_device)
+                if pin_memory:
+                    m._apply(lambda x: x.pin_memory())
+                swap_counter += m.total_mem
+
+            for m in gpu_modules_only_extras:
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
+                module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
+                if hasattr(m, 'weight') and m.weight is not None:
+                    if pin_memory:
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device).pin_memory())
+                    else:
+                        m.weight = utils.tensor2parameter(m.weight.to(self.model.offload_device))
+                mem_counter += m.extra_mem
+                swap_counter += m.weight_mem
 
             swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
             method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
-            print(f"[Memory Management] Loaded to {swap_flag} Swap: {memory_in_swap / (1024 * 1024):.2f} MB ({method_flag} method)")
+            print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
             print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
 
         self.model_accelerated = True
 
         global signal_empty_cache
         signal_empty_cache = True
 
         if is_intel_xpu() and not args.disable_ipex_hijack:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
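For illustration, both offload paths above pin the swapped tensors when PIN_SHARED_MEMORY is set and the offload device is the CPU. A minimal PyTorch sketch of what pinning buys, assuming a CUDA build; the tensor names below are made up:

# Illustration only, assuming a PyTorch build with CUDA; tensor names are made up.
import torch

if torch.cuda.is_available():
    w = torch.randn(1024, 1024)                     # a weight that lives in CPU swap
    w_pinned = w.pin_memory()                       # page-locked, like m._apply(lambda x: x.pin_memory())
    w_gpu = w_pinned.to('cuda', non_blocking=True)  # host-to-device copy can overlap with other work
    torch.cuda.synchronize()                        # wait for the copy before using w_gpu

Copies from pinned (page-locked) host memory can run asynchronously with respect to the host, whereas copies from ordinary pageable memory are synchronous, which is presumably what the stream-based "asynchronous method" path relies on.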
@@ -412,17 +412,8 @@ class ForgeOperationsGGUF(ForgeOperations):
                 return
 
         def _apply(self, fn, recurse=True):
-            if self.weight is not None:
-                self.weight = utils.tensor2parameter(fn(self.weight))
-            if self.bias is not None:
-                self.bias = utils.tensor2parameter(fn(self.bias))
-            for i in range(5):
-                quant_state_name = f'quant_state_{i}'
-                quant_state = getattr(self, quant_state_name, None)
-                if quant_state is not None:
-                    quant_state = fn(quant_state)
-                    quant_state = utils.tensor2parameter(quant_state)
-                    setattr(self, quant_state_name, quant_state)
+            for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+                setattr(self, k, utils.tensor2parameter(fn(p)))
             return self
 
         def forward(self, x):
@@ -92,10 +92,8 @@ class ForgeLoader4Bit(torch.nn.Module):
         self.quant_type = quant_type
 
     def _apply(self, fn, recurse=True):
-        if self.weight is not None:
-            self.weight = utils.tensor2parameter(fn(self.weight))
-        if self.bias is not None:
-            self.bias = utils.tensor2parameter(fn(self.bias))
+        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
+            setattr(self, k, utils.tensor2parameter(fn(p)))
         return self
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
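For illustration, both _apply overrides above are rewritten to the same generic loop over named_parameters(), so every registered parameter (weight, bias, or quant state tensors registered as parameters) is transformed by fn without being listed by hand. A standalone sketch (not Forge code) of that pattern; TinyModule is invented here, and a plain torch.nn.Parameter stands in for utils.tensor2parameter:

# Standalone sketch (not Forge code) of the generic _apply pattern used above.
# TinyModule is invented; a plain torch.nn.Parameter stands in for utils.tensor2parameter.
import torch

class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        self.quant_state_0 = torch.nn.Parameter(torch.randn(4), requires_grad=False)

    def _apply(self, fn, recurse=True):
        # walk every registered parameter instead of naming weight/bias/quant_state_* by hand
        for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
            with torch.no_grad():
                transformed = fn(p)
            setattr(self, k, torch.nn.Parameter(transformed, requires_grad=p.requires_grad))
        return self

m = TinyModule().to(torch.float16)            # .to() routes through _apply
print(m.weight.dtype, m.quant_state_0.dtype)  # both become torch.float16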
packages_3rdparty/gguf/quants.py (vendored, 2 lines changed)

@@ -620,8 +620,6 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
-        if d.device != x.device:
-            d = d.to(device=x.device)
 
         x = x.to(cls.computation_dtype)
 
         return x * d
 
     @classmethod
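For illustration, the two removed lines were a device-mismatch guard for the per-block scale d; the remaining computation is just the scale broadcast over the int8 block values. A minimal sketch of that shape pattern with made-up data (GGML's Q8_0 packs 32 values per block with one fp16 scale):

# Illustration only (made-up data): GGML's Q8_0 stores 32 int8 values per block
# plus one fp16 scale; dequantization is the scale broadcast over the block.
import torch

d = torch.tensor([[0.5], [2.0]], dtype=torch.float16)    # one scale per block
x = torch.randint(-128, 128, (2, 32), dtype=torch.int8)  # 32 quantized values per block
y = x.to(torch.float32) * d.to(torch.float32)            # like x.to(cls.computation_dtype) followed by x * d
print(y.shape)  # torch.Size([2, 32])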