Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
speed up lora using cuda profile
@@ -375,9 +375,9 @@ class LoadedModel:
         self.model.model_patches_to(self.model.model_dtype())
 
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to)
+            self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
         except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model.forge_unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e
 
@@ -429,9 +429,9 @@ class LoadedModel:
             self.model_accelerated = False
 
         if avoid_model_moving:
-            self.model.unpatch_model()
+            self.model.forge_unpatch_model()
         else:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model.forge_unpatch_model(self.model.offload_device)
             self.model.model_patches_to(self.model.offload_device)
 
     def __eq__(self, other):
@@ -6,7 +6,7 @@ import torch
 import copy
 import inspect
 
-from backend import memory_management, utils
+from backend import memory_management, utils, operations
 from backend.patcher.lora import merge_lora_to_model_weight
 
 
@@ -47,7 +47,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
 
 
 class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
         self.size = size
         self.model = model
         self.patches = {}
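A plausible reading of the signature change (my assumption; the commit message does not spell it out): switching from an explicit weight_inplace_update=False parameter to **kwargs lets older call sites that still pass the removed flag keep working, while the patcher itself simply ignores it. A minimal, self-contained sketch of that compatibility behaviour, with ForgeStylePatcher as a hypothetical stand-in class, not the real ModelPatcher:

    class ForgeStylePatcher:
        # Hypothetical illustration only, not the real ModelPatcher.
        def __init__(self, model, load_device, offload_device, size=0,
                     current_device=None, **kwargs):
            # Unknown legacy keyword arguments (e.g. weight_inplace_update)
            # are accepted but no longer change behaviour.
            self.model = model
            self.load_device = load_device
            self.offload_device = offload_device
            self.size = size
            self.current_device = current_device

    # An old caller that still passes the removed flag does not crash:
    patcher = ForgeStylePatcher(object(), 'cuda', 'cpu', weight_inplace_update=True)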
@@ -63,8 +63,6 @@ class ModelPatcher:
         else:
             self.current_device = current_device
 
-        self.weight_inplace_update = weight_inplace_update
-
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -72,7 +70,7 @@ class ModelPatcher:
         return self.size
 
     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -236,39 +234,40 @@ class ModelPatcher:
             sd.pop(k)
         return sd
 
-    def patch_model(self, device_to=None):
-        for k in self.object_patches:
+    def forge_patch_model(self, device_to=None):
+        for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
 
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            utils.set_attr_raw(self.model, k, self.object_patches[k])
 
-        model_state_dict = self.model_state_dict()
+            utils.set_attr_raw(self.model, k, item)
 
         for key, current_patches in self.patches.items():
-            assert key in model_state_dict, f"Wrong LoRA Key: {key}"
-            weight = model_state_dict[key]
-
-            if weight.dtype == torch.uint8:
-                raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
-
-            inplace_update = self.weight_inplace_update
+            try:
+                weight = utils.get_attr(self.model, key)
+                assert isinstance(weight, torch.nn.Parameter)
+            except:
+                raise ValueError(f"Wrong LoRA Key: {key}")
 
             if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+                self.backup[key] = weight.to(device=self.offload_device)
+
+            if operations.bnb_avaliable:
+                if hasattr(weight, 'bnb_quantized'):
+                    raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
+
+            to_args = dict(dtype=torch.float32)
 
             if device_to is not None:
-                temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
+                to_args['device'] = device_to
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
+
+            temp_weight = weight.to(**to_args)
 
             out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
 
-            if inplace_update:
-                utils.copy_to_param(self.model, key, out_weight)
-            else:
-                utils.set_attr(self.model, key, out_weight)
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
 
         if device_to is not None:
             self.model.to(device_to)
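The weight-merge path above looks like the heart of the speed-up: rather than memory_management.cast_to_device(..., copy=True), the weight is moved once with a single .to() call that combines dtype, device and non_blocking, and the merged result is written back as a plain nn.Parameter. A standalone sketch of that pattern, assuming nothing beyond PyTorch; supports_non_blocking and merge_delta are illustrative stand-ins, not the repo's memory_management.device_supports_non_blocking or merge_lora_to_model_weight:

    import torch

    def supports_non_blocking(device: torch.device) -> bool:
        # Stand-in: pinned-host -> CUDA copies can overlap with compute.
        return device.type == 'cuda'

    def merge_delta(weight_fp32, delta):
        # Stand-in for merge_lora_to_model_weight: add a precomputed delta.
        return weight_fp32 + delta

    def patch_weight(weight, delta, device_to=None):
        to_args = dict(dtype=torch.float32)
        if device_to is not None:
            to_args['device'] = device_to
            to_args['non_blocking'] = supports_non_blocking(torch.device(device_to))
        temp = weight.to(**to_args)  # one cast + move, asynchronous when possible
        out = merge_delta(temp, delta.to(temp)).to(weight.dtype)
        return torch.nn.Parameter(out, requires_grad=False)

    # Works on CPU; pass device_to='cuda' when a GPU is available.
    w = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16), requires_grad=False)
    patched = patch_weight(w, delta=torch.full((4, 4), 0.01), device_to=None)
    print(patched.dtype)  # torch.float16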
@@ -276,15 +275,17 @@ class ModelPatcher:
 
         return self.model
 
-    def unpatch_model(self, device_to=None):
+    def forge_unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())
 
-        if self.weight_inplace_update:
-            for k in keys:
-                utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                utils.set_attr(self.model, k, self.backup[k])
+        for k in keys:
+            w = self.backup[k]
+
+            if not isinstance(w, torch.nn.Parameter):
+                # In very few cases
+                w = torch.nn.Parameter(w, requires_grad=False)
+
+            utils.set_attr_raw(self.model, k, w)
 
         self.backup = {}
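As a sanity check on the backup -> merge -> restore cycle that forge_patch_model and forge_unpatch_model implement, here is a toy, self-contained version of the same idea (my simplification, not Forge's API): back up each weight on the offload device before merging, then restore the backup, re-wrapping it as a Parameter when needed:

    import torch

    class TinyPatcher:
        # Toy illustration of backup -> merge -> restore; not the real ModelPatcher.
        def __init__(self, model, offload_device='cpu'):
            self.model = model
            self.offload_device = offload_device
            self.backup = {}

        def patch(self, deltas):
            for name, delta in deltas.items():
                weight = getattr(self.model, name)
                if name not in self.backup:
                    self.backup[name] = weight.to(device=self.offload_device)
                merged = (weight.to(torch.float32) + delta).to(weight.dtype)
                setattr(self.model, name, torch.nn.Parameter(merged, requires_grad=False))

        def unpatch(self):
            for name, w in self.backup.items():
                if not isinstance(w, torch.nn.Parameter):
                    w = torch.nn.Parameter(w, requires_grad=False)
                setattr(self.model, name, w)
            self.backup = {}

    model = torch.nn.Linear(4, 4, bias=False)
    original = model.weight.detach().clone()
    p = TinyPatcher(model)
    p.patch({'weight': torch.full((4, 4), 0.5)})
    p.unpatch()
    print(torch.equal(model.weight, original))  # True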
@@ -297,3 +298,4 @@ class ModelPatcher:
             utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
 
         self.object_patches_backup = {}
+        return
@@ -24,8 +24,7 @@ class UnetPatcher(ModelPatcher):
         self.extra_concat_condition = None
 
     def clone(self):
-        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
-                        weight_inplace_update=self.weight_inplace_update)
+        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
 
         n.patches = {}
         for k in self.patches: