diff --git a/backend/memory_management.py b/backend/memory_management.py
index b3fd3d9b..86e9f09f 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -375,9 +375,9 @@ class LoadedModel:
         self.model.model_patches_to(self.model.model_dtype())
 
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to)
+            self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
         except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model.forge_unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e
 
@@ -429,9 +429,9 @@ class LoadedModel:
             self.model_accelerated = False
 
         if avoid_model_moving:
-            self.model.unpatch_model()
+            self.model.forge_unpatch_model()
         else:
-            self.model.unpatch_model(self.model.offload_device)
+            self.model.forge_unpatch_model(self.model.offload_device)
             self.model.model_patches_to(self.model.offload_device)
 
     def __eq__(self, other):
diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 074a0e83..affd0a73 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -6,7 +6,7 @@
 import torch
 import copy
 import inspect
 
-from backend import memory_management, utils
+from backend import memory_management, utils, operations
 from backend.patcher.lora import merge_lora_to_model_weight
 
@@ -47,7 +47,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
 
 
 class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
         self.size = size
         self.model = model
         self.patches = {}
@@ -63,8 +63,6 @@ class ModelPatcher:
         else:
             self.current_device = current_device
 
-        self.weight_inplace_update = weight_inplace_update
-
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -72,7 +70,7 @@ class ModelPatcher:
         return self.size
 
     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -236,39 +234,40 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
-    def patch_model(self, device_to=None):
-        for k in self.object_patches:
+    def forge_patch_model(self, device_to=None):
+        for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
+
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            utils.set_attr_raw(self.model, k, self.object_patches[k])
 
-        model_state_dict = self.model_state_dict()
+            utils.set_attr_raw(self.model, k, item)
 
         for key, current_patches in self.patches.items():
-            assert key in model_state_dict, f"Wrong LoRA Key: {key}"
-
-            weight = model_state_dict[key]
-
-            if weight.dtype == torch.uint8:
-                raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
-
-            inplace_update = self.weight_inplace_update
+            try:
+                weight = utils.get_attr(self.model, key)
+                assert isinstance(weight, torch.nn.Parameter)
+            except:
+                raise ValueError(f"Wrong LoRA Key: {key}")
 
             if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+                self.backup[key] = weight.to(device=self.offload_device)
+
+            if operations.bnb_avaliable:
+                if hasattr(weight, 'bnb_quantized'):
+                    raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
+
+            to_args = dict(dtype=torch.float32)
 
             if device_to is not None:
-                temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
+                to_args['device'] = device_to
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
+
+            temp_weight = weight.to(**to_args)
 
             out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
 
-            if inplace_update:
-                utils.copy_to_param(self.model, key, out_weight)
-            else:
-                utils.set_attr(self.model, key, out_weight)
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
 
         if device_to is not None:
             self.model.to(device_to)
@@ -276,15 +275,17 @@ class ModelPatcher:
 
         return self.model
 
-    def unpatch_model(self, device_to=None):
+    def forge_unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())
 
-        if self.weight_inplace_update:
-            for k in keys:
-                utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                utils.set_attr(self.model, k, self.backup[k])
+        for k in keys:
+            w = self.backup[k]
+
+            if not isinstance(w, torch.nn.Parameter):
+                # In very few cases
+                w = torch.nn.Parameter(w, requires_grad=False)
+
+            utils.set_attr_raw(self.model, k, w)
 
         self.backup = {}
 
@@ -297,3 +298,4 @@
             utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
 
         self.object_patches_backup = {}
+        return
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
index e8416f70..eeaea671 100644
--- a/backend/patcher/unet.py
+++ b/backend/patcher/unet.py
@@ -24,8 +24,7 @@ class UnetPatcher(ModelPatcher):
         self.extra_concat_condition = None
 
     def clone(self):
-        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
-                        weight_inplace_update=self.weight_inplace_update)
+        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
         n.patches = {}
        for k in self.patches:
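For reference, here is a minimal, self-contained sketch of the backup/merge/swap pattern that `forge_patch_model()` and `forge_unpatch_model()` now follow: back up each target `torch.nn.Parameter`, merge the patch into a float32 copy, install the result as a fresh frozen `Parameter` via a raw attribute write, and restore the backups on unpatch. `TinyPatcher` and `merge_delta()` are illustrative stand-ins, and the two dotted-path helpers only mirror what `utils.get_attr`/`utils.set_attr_raw` do; none of this is the backend's actual implementation.

```python
import functools

import torch


def get_attr(obj, name):
    # Walk a dotted path like "0.weight" to the leaf attribute
    # (mirrors utils.get_attr in the backend).
    return functools.reduce(getattr, name.split("."), obj)


def set_attr_raw(obj, name, value):
    # Assign the leaf attribute directly (mirrors utils.set_attr_raw).
    *path, leaf = name.split(".")
    setattr(functools.reduce(getattr, path, obj), leaf, value)


def merge_delta(weight_fp32, delta):
    # Placeholder for merge_lora_to_model_weight(): just add a delta.
    return weight_fp32 + delta.to(weight_fp32)


class TinyPatcher:
    def __init__(self, model, offload_device=torch.device("cpu")):
        self.model = model
        self.offload_device = offload_device
        self.backup = {}

    @torch.no_grad()
    def patch(self, patches, device_to=None):
        for key, delta in patches.items():
            weight = get_attr(self.model, key)
            assert isinstance(weight, torch.nn.Parameter), f"Wrong key: {key}"

            # Keep an unmodified copy so unpatch() can restore it later.
            if key not in self.backup:
                self.backup[key] = weight.to(device=self.offload_device)

            # Merge in float32, then cast back to the weight's dtype.
            temp = weight.to(dtype=torch.float32,
                             device=device_to if device_to is not None else weight.device)
            merged = merge_delta(temp, delta).to(weight.dtype)

            # Swap the merged tensor in as a fresh, frozen Parameter.
            set_attr_raw(self.model, key,
                         torch.nn.Parameter(merged, requires_grad=False))

    def unpatch(self):
        for key, w in self.backup.items():
            # .to() may have returned the original Parameter unchanged;
            # re-wrap plain tensors so the module gets a Parameter back.
            if not isinstance(w, torch.nn.Parameter):
                w = torch.nn.Parameter(w, requires_grad=False)
            set_attr_raw(self.model, key, w)
        self.backup = {}


model = torch.nn.Sequential(torch.nn.Linear(4, 4))
patcher = TinyPatcher(model)
original = model[0].weight.detach().clone()

patcher.patch({"0.weight": torch.full((4, 4), 0.1)})
assert not torch.equal(model[0].weight, original)

patcher.unpatch()
assert torch.equal(model[0].weight, original)
```

The `isinstance` check in `unpatch()` mirrors the diff's "In very few cases" branch: `.to()` returns the original `Parameter` unchanged when no move is needed, so backups are not always plain tensors.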
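The `to_args` change in `forge_patch_model()` replaces the old branch between `cast_to_device(..., copy=True)` and a plain `.to()` with a single kwargs-driven `.to()` call. A small sketch of that pattern, with `supports_non_blocking()` as a simplified stand-in for the backend's `memory_management.device_supports_non_blocking()`:

```python
import torch


def supports_non_blocking(device: torch.device) -> bool:
    # Simplified stand-in: async copies only pay off for CUDA targets;
    # CPU (and typically MPS) copies are effectively synchronous.
    return device.type == "cuda"


def cast_weight_for_merge(weight: torch.Tensor, device_to=None) -> torch.Tensor:
    # Build the .to() kwargs once; non_blocking is only added when a
    # destination device is actually requested.
    to_args = dict(dtype=torch.float32)
    if device_to is not None:
        to_args["device"] = device_to
        to_args["non_blocking"] = supports_non_blocking(device_to)
    return weight.to(**to_args)


w = torch.randn(8, 8, dtype=torch.float16)
print(cast_weight_for_merge(w).dtype)  # torch.float32, stays on CPU
```

One PyTorch detail worth noting: unlike the old `copy=True` path, `.to()` may return the original tensor when the weight is already float32 on the target device, so the backup taken before the merge is what guarantees the weight can be restored.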