From a29875206f10a3aa98fff0b2e66fce53a33d0bca Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 14 Aug 2024 20:39:05 -0700
Subject: [PATCH] Revert "simplify codes"

This reverts commit e7567efd4b04d06233b6a4c507fe70bd1aa2d536.
---
 backend/patcher/base.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index d5fd476b..dbeed0fa 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -236,9 +236,7 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
-    def forge_patch_model(self, target_device):
-        assert isinstance(target_device, torch.device)
-
+    def forge_patch_model(self, device_to=None):
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
 
@@ -264,15 +262,25 @@ class ModelPatcher:
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
-                    assert target_device.type == 'cuda', 'BNB Must use CUDA!'
-                    weight = weight.to(target_device)
+                    if device_to is not None:
+                        assert device_to.type == 'cuda', 'BNB Must use CUDA!'
+                        weight = weight.to(device_to)
+                    else:
+                        weight = weight.cuda()
+
                     from backend.operations_bnb import functional_dequantize_4bit
                     weight = functional_dequantize_4bit(weight)
                 else:
                     weight = weight.data
 
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32, device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+            to_args = dict(dtype=torch.float32)
+
+            if device_to is not None:
+                to_args['device'] = device_to
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
+
+            weight = weight.to(**to_args)
             out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
             if bnb_layer is not None:
@@ -281,12 +289,13 @@ class ModelPatcher:
 
                 utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
 
-        self.model.to(target_device)
-        self.current_device = target_device
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
 
         return self.model
 
-    def forge_unpatch_model(self, target_device=None):
+    def forge_unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())
 
         for k in keys:
@@ -300,9 +309,9 @@ class ModelPatcher:
 
         self.backup = {}
 
-        if target_device is not None:
-            self.model.to(target_device)
-            self.current_device = target_device
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
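
For context, the revert restores the optional `device_to` signature on `forge_patch_model` / `forge_unpatch_model`, so callers may either patch weights in place or pass an explicit target device. Below is a minimal usage sketch of the restored calling convention; the `patcher` variable is a hypothetical, pre-built `backend.patcher.base.ModelPatcher` instance (an assumption for illustration, not part of the patch):

import torch

# Patch LoRA weights and move the whole model to an explicit device.
# Per the patch, BNB-quantized weights assert that this device is CUDA.
patcher.forge_patch_model(device_to=torch.device('cuda:0'))

# Or patch in place: with device_to omitted (None), weights are merged
# where they are, and BNB-quantized weights fall back to .cuda() before
# functional_dequantize_4bit.
patcher.forge_patch_model()

# Undo the patches; passing a device also moves the model back to it,
# otherwise the model stays on its current device.
patcher.forge_unpatch_model(device_to=torch.device('cpu'))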