mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-28 10:43:58 +00:00
speed up lora using cuda profile
This commit is contained in:
@@ -375,9 +375,9 @@ class LoadedModel:
|
||||
self.model.model_patches_to(self.model.model_dtype())
|
||||
|
||||
try:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to)
|
||||
self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model.forge_unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
@@ -429,9 +429,9 @@ class LoadedModel:
|
||||
self.model_accelerated = False
|
||||
|
||||
if avoid_model_moving:
|
||||
self.model.unpatch_model()
|
||||
self.model.forge_unpatch_model()
|
||||
else:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model.forge_unpatch_model(self.model.offload_device)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
from backend import memory_management, utils
|
||||
from backend import memory_management, utils, operations
|
||||
from backend.patcher.lora import merge_lora_to_model_weight
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
@@ -63,8 +63,6 @@ class ModelPatcher:
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@@ -72,7 +70,7 @@ class ModelPatcher:
|
||||
return self.size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -236,39 +234,40 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None):
|
||||
for k in self.object_patches:
|
||||
def forge_patch_model(self, device_to=None):
|
||||
for k, item in self.object_patches.items():
|
||||
old = utils.get_attr(self.model, k)
|
||||
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
utils.set_attr_raw(self.model, k, self.object_patches[k])
|
||||
|
||||
model_state_dict = self.model_state_dict()
|
||||
utils.set_attr_raw(self.model, k, item)
|
||||
|
||||
for key, current_patches in self.patches.items():
|
||||
assert key in model_state_dict, f"Wrong LoRA Key: {key}"
|
||||
|
||||
weight = model_state_dict[key]
|
||||
|
||||
if weight.dtype == torch.uint8:
|
||||
raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
try:
|
||||
weight = utils.get_attr(self.model, key)
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
except:
|
||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
self.backup[key] = weight.to(device=self.offload_device)
|
||||
|
||||
if operations.bnb_avaliable:
|
||||
if hasattr(weight, 'bnb_quantized'):
|
||||
raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
|
||||
|
||||
to_args = dict(dtype=torch.float32)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
to_args['device'] = device_to
|
||||
to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
|
||||
|
||||
temp_weight = weight.to(**to_args)
|
||||
|
||||
out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
|
||||
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
utils.set_attr(self.model, key, out_weight)
|
||||
utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
@@ -276,15 +275,17 @@ class ModelPatcher:
|
||||
|
||||
return self.model
|
||||
|
||||
def unpatch_model(self, device_to=None):
|
||||
def forge_unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
utils.set_attr(self.model, k, self.backup[k])
|
||||
for k in keys:
|
||||
w = self.backup[k]
|
||||
|
||||
if not isinstance(w, torch.nn.Parameter):
|
||||
# In very few cases
|
||||
w = torch.nn.Parameter(w, requires_grad=False)
|
||||
|
||||
utils.set_attr_raw(self.model, k, w)
|
||||
|
||||
self.backup = {}
|
||||
|
||||
@@ -297,3 +298,4 @@ class ModelPatcher:
|
||||
utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup = {}
|
||||
return
|
||||
|
||||
@@ -24,8 +24,7 @@ class UnetPatcher(ModelPatcher):
|
||||
self.extra_concat_condition = None
|
||||
|
||||
def clone(self):
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
||||
weight_inplace_update=self.weight_inplace_update)
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
|
||||
Reference in New Issue
Block a user