Speed up LoRA using CUDA profiling

This commit is contained in:
layerdiffusion
2024-08-14 19:09:35 -07:00
parent df0fee9396
commit aff742b597
3 changed files with 39 additions and 38 deletions

View File

@@ -375,9 +375,9 @@ class LoadedModel:
self.model.model_patches_to(self.model.model_dtype()) self.model.model_patches_to(self.model.model_dtype())
try: try:
self.real_model = self.model.patch_model(device_to=patch_model_to) self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
except Exception as e: except Exception as e:
self.model.unpatch_model(self.model.offload_device) self.model.forge_unpatch_model(self.model.offload_device)
self.model_unload() self.model_unload()
raise e raise e
@@ -429,9 +429,9 @@ class LoadedModel:
self.model_accelerated = False self.model_accelerated = False
if avoid_model_moving: if avoid_model_moving:
self.model.unpatch_model() self.model.forge_unpatch_model()
else: else:
self.model.unpatch_model(self.model.offload_device) self.model.forge_unpatch_model(self.model.offload_device)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)
def __eq__(self, other): def __eq__(self, other):

View File

@@ -6,7 +6,7 @@ import torch
import copy import copy
import inspect import inspect
from backend import memory_management, utils from backend import memory_management, utils, operations
from backend.patcher.lora import merge_lora_to_model_weight from backend.patcher.lora import merge_lora_to_model_weight
@@ -47,7 +47,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
self.size = size self.size = size
self.model = model self.model = model
self.patches = {} self.patches = {}
@@ -63,8 +63,6 @@ class ModelPatcher:
else: else:
self.current_device = current_device self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
return self.size return self.size
@@ -72,7 +70,7 @@ class ModelPatcher:
return self.size return self.size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@@ -236,39 +234,40 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_model(self, device_to=None): def forge_patch_model(self, device_to=None):
for k in self.object_patches: for k, item in self.object_patches.items():
old = utils.get_attr(self.model, k) old = utils.get_attr(self.model, k)
if k not in self.object_patches_backup: if k not in self.object_patches_backup:
self.object_patches_backup[k] = old self.object_patches_backup[k] = old
utils.set_attr_raw(self.model, k, self.object_patches[k])
model_state_dict = self.model_state_dict() utils.set_attr_raw(self.model, k, item)
for key, current_patches in self.patches.items(): for key, current_patches in self.patches.items():
assert key in model_state_dict, f"Wrong LoRA Key: {key}" try:
weight = utils.get_attr(self.model, key)
weight = model_state_dict[key] assert isinstance(weight, torch.nn.Parameter)
except:
if weight.dtype == torch.uint8: raise ValueError(f"Wrong LoRA Key: {key}")
raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
inplace_update = self.weight_inplace_update
if key not in self.backup: if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) self.backup[key] = weight.to(device=self.offload_device)
if operations.bnb_avaliable:
if hasattr(weight, 'bnb_quantized'):
raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
to_args = dict(dtype=torch.float32)
if device_to is not None: if device_to is not None:
temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True) to_args['device'] = device_to
else: to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
temp_weight = weight.to(torch.float32, copy=True)
temp_weight = weight.to(**to_args)
out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype) out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
if inplace_update: utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
if device_to is not None: if device_to is not None:
self.model.to(device_to) self.model.to(device_to)
@@ -276,15 +275,17 @@ class ModelPatcher:
return self.model return self.model
def unpatch_model(self, device_to=None): def forge_unpatch_model(self, device_to=None):
keys = list(self.backup.keys()) keys = list(self.backup.keys())
if self.weight_inplace_update: for k in keys:
for k in keys: w = self.backup[k]
utils.copy_to_param(self.model, k, self.backup[k])
else: if not isinstance(w, torch.nn.Parameter):
for k in keys: # In very few cases
utils.set_attr(self.model, k, self.backup[k]) w = torch.nn.Parameter(w, requires_grad=False)
utils.set_attr_raw(self.model, k, w)
self.backup = {} self.backup = {}
@@ -297,3 +298,4 @@ class ModelPatcher:
utils.set_attr_raw(self.model, k, self.object_patches_backup[k]) utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {} self.object_patches_backup = {}
return

View File

@@ -24,8 +24,7 @@ class UnetPatcher(ModelPatcher):
self.extra_concat_condition = None self.extra_concat_condition = None
def clone(self): def clone(self):
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches: