speed up lora using cuda profile

layerdiffusion
2024-08-14 19:09:35 -07:00
parent df0fee9396
commit aff742b597
3 changed files with 39 additions and 38 deletions

View File

@@ -375,9 +375,9 @@ class LoadedModel:
         self.model.model_patches_to(self.model.model_dtype())
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to)
+            self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
         except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model.forge_unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e
@@ -429,9 +429,9 @@ class LoadedModel:
         self.model_accelerated = False

         if avoid_model_moving:
-            self.model.unpatch_model()
+            self.model.forge_unpatch_model()
         else:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model.forge_unpatch_model(self.model.offload_device)
             self.model.model_patches_to(self.model.offload_device)

     def __eq__(self, other):
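Both hunks in this file are pure call-site renames: LoadedModel now drives the patcher through forge_patch_model / forge_unpatch_model, with the error-handling shape unchanged. Read in isolation, the pattern is roughly this (a sketch; the surrounding method name is illustrative, and the class is reduced to the two calls shown above):

    class LoadedModel:
        def model_load(self, patch_model_to):
            try:
                # Apply all pending LoRA/object patches, optionally moving
                # the model to the target device in the same pass.
                self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
            except Exception as e:
                # Undo any partially applied patches before unloading, then
                # re-raise so the caller still sees the original failure.
                self.model.forge_unpatch_model(self.model.offload_device)
                self.model_unload()
                raise e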

View File

@@ -6,7 +6,7 @@ import torch
 import copy
 import inspect

-from backend import memory_management, utils
+from backend import memory_management, utils, operations
 from backend.patcher.lora import merge_lora_to_model_weight
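The new operations import feeds the bitsandbytes guard added further down: operations.bnb_avaliable (identifier spelled as in the codebase) presumably reports whether bitsandbytes loaded, and quantized weights are detected by their bnb_quantized attribute. A self-contained sketch of that guard, with the module stubbed out:

    import types
    import torch

    # Stand-in for backend.operations; assuming bnb_avaliable is a plain bool.
    operations = types.SimpleNamespace(bnb_avaliable=True)

    def check_lora_target(weight: torch.Tensor):
        # Mirrors the check in forge_patch_model below: quantized NF4/FP4
        # weights cannot be merged with LoRA deltas yet, so fail loudly.
        if operations.bnb_avaliable and hasattr(weight, 'bnb_quantized'):
            raise NotImplementedError('LoRAs for NF4/FP4 models are not available yet.')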
@@ -47,7 +47,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
 class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
         self.size = size
         self.model = model
         self.patches = {}
@@ -63,8 +63,6 @@ class ModelPatcher:
         else:
             self.current_device = current_device

-        self.weight_inplace_update = weight_inplace_update
-
     def model_size(self):
         if self.size > 0:
             return self.size
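Dropping weight_inplace_update in favor of **kwargs keeps old constructor calls working: a caller that still passes weight_inplace_update=... no longer configures anything, but it also no longer raises a TypeError. A minimal sketch of that compatibility pattern (class trimmed to the signature):

    class ModelPatcher:
        def __init__(self, model, load_device, offload_device, size=0,
                     current_device=None, **kwargs):
            # Legacy keyword arguments such as weight_inplace_update land in
            # kwargs and are silently ignored instead of breaking the call.
            self.model = model

    # Still accepted after the change; the flag is simply discarded:
    p = ModelPatcher(object(), 'cuda', 'cpu', weight_inplace_update=True)

With the flag gone, patching always rebinds whole parameters (see forge_patch_model below) rather than copying into them in place.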
@@ -72,7 +70,7 @@ class ModelPatcher:
         return self.size

     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -236,39 +234,40 @@ class ModelPatcher:
             sd.pop(k)
         return sd

-    def patch_model(self, device_to=None):
-        for k in self.object_patches:
+    def forge_patch_model(self, device_to=None):
+        for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            utils.set_attr_raw(self.model, k, self.object_patches[k])
-
-        model_state_dict = self.model_state_dict()
+            utils.set_attr_raw(self.model, k, item)

         for key, current_patches in self.patches.items():
-            assert key in model_state_dict, f"Wrong LoRA Key: {key}"
-
-            weight = model_state_dict[key]
-
-            if weight.dtype == torch.uint8:
-                raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
-
-            inplace_update = self.weight_inplace_update
+            try:
+                weight = utils.get_attr(self.model, key)
+                assert isinstance(weight, torch.nn.Parameter)
+            except:
+                raise ValueError(f"Wrong LoRA Key: {key}")

             if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+                self.backup[key] = weight.to(device=self.offload_device)
+
+            if operations.bnb_avaliable:
+                if hasattr(weight, 'bnb_quantized'):
+                    raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
+
+            to_args = dict(dtype=torch.float32)

             if device_to is not None:
-                temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
+                to_args['device'] = device_to
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)

+            temp_weight = weight.to(**to_args)
             out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
-
-            if inplace_update:
-                utils.copy_to_param(self.model, key, out_weight)
-            else:
-                utils.set_attr(self.model, key, out_weight)
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))

         if device_to is not None:
             self.model.to(device_to)
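The transfer rewrite is the likely source of the speedup named in the commit title: the old path forced a synchronous copy per weight via cast_to_device(..., copy=True), while the new path moves and upcasts in a single weight.to(**to_args) call, setting non_blocking when memory_management.device_supports_non_blocking says the target allows it. On CUDA, a non-blocking host-to-device copy from pinned memory is queued on the current stream instead of stalling the CPU, so consecutive LoRA merges can overlap. A standalone illustration of the idea (pure PyTorch; function and variable names hypothetical):

    import torch

    def merge_on_gpu(weight: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        # Move + upcast in a single .to() call, as in the diff's
        # weight.to(**to_args). With non_blocking=True and a pinned source,
        # the H2D copy is enqueued asynchronously on the current CUDA stream.
        temp = weight.to(device='cuda', dtype=torch.float32, non_blocking=True)
        # Later CUDA ops on temp are stream-ordered after the copy, so no
        # explicit synchronization is needed before using it.
        return (temp + delta).to(weight.dtype)

    if torch.cuda.is_available():
        w = torch.randn(1024, 1024, dtype=torch.float16).pin_memory()
        d = torch.zeros(1024, 1024, dtype=torch.float32, device='cuda')
        out = merge_on_gpu(w, d)

Note the backup also drops copy= (self.backup[key] = weight.to(device=self.offload_device)): when the weight already lives on the offload device, .to() returns the tensor itself rather than duplicating it.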
@@ -276,15 +275,17 @@ class ModelPatcher:
         return self.model

-    def unpatch_model(self, device_to=None):
+    def forge_unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())

-        if self.weight_inplace_update:
-            for k in keys:
-                utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                utils.set_attr(self.model, k, self.backup[k])
+        for k in keys:
+            w = self.backup[k]
+            if not isinstance(w, torch.nn.Parameter):
+                # In very few cases
+                w = torch.nn.Parameter(w, requires_grad=False)
+            utils.set_attr_raw(self.model, k, w)

         self.backup = {}
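Unpatching is simplified to match: the inplace/non-inplace branch disappears, and every backup is rebound with utils.set_attr_raw, wrapped in a torch.nn.Parameter when it isn't one already. The helper's presumed behavior (rebind an attribute by dotted path without copying tensor data) can be sketched in plain PyTorch; only the name comes from the diff, the body is an assumption:

    import functools
    import torch

    def set_attr_raw(obj, path: str, value):
        # Presumed behavior of utils.set_attr_raw: walk the dotted module
        # path and rebind the final attribute directly, avoiding the
        # per-element copy an in-place copy_-style restore would perform.
        *parents, leaf = path.split('.')
        parent = functools.reduce(getattr, parents, obj)
        setattr(parent, leaf, value)

    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    restored = torch.nn.Parameter(torch.zeros(4, 4), requires_grad=False)
    set_attr_raw(model, '0.weight', restored)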
@@ -297,3 +298,4 @@ class ModelPatcher:
             utils.set_attr_raw(self.model, k, self.object_patches_backup[k])

         self.object_patches_backup = {}
+        return

View File

@@ -24,8 +24,7 @@ class UnetPatcher(ModelPatcher):
         self.extra_concat_condition = None

     def clone(self):
-        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
-                        weight_inplace_update=self.weight_inplace_update)
+        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
         n.patches = {}
         for k in self.patches: