Revert "simplify codes"

This reverts commit e7567efd4b.
Author: layerdiffusion
Date:   2024-08-14 20:39:05 -07:00
Commit: a29875206f
Parent: b31f81628f

@@ -236,9 +236,7 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
-    def forge_patch_model(self, target_device):
-        assert isinstance(target_device, torch.device)
-
+    def forge_patch_model(self, device_to=None):
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
@@ -264,15 +262,25 @@ class ModelPatcher:
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
-                    assert target_device.type == 'cuda', 'BNB Must use CUDA!'
-                    weight = weight.to(target_device)
+                    if device_to is not None:
+                        assert device_to.type == 'cuda', 'BNB Must use CUDA!'
+                        weight = weight.to(device_to)
+                    else:
+                        weight = weight.cuda()
                 from backend.operations_bnb import functional_dequantize_4bit
                 weight = functional_dequantize_4bit(weight)
             else:
                 weight = weight.data
 
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32, device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+            to_args = dict(dtype=torch.float32)
+            if device_to is not None:
+                to_args['device'] = device_to
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
+            weight = weight.to(**to_args)
 
             out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
             if bnb_layer is not None:
@@ -281,12 +289,13 @@ class ModelPatcher:
                 utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
 
-        self.model.to(target_device)
-        self.current_device = target_device
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
 
         return self.model
 
-    def forge_unpatch_model(self, target_device=None):
+    def forge_unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())
         for k in keys:
@@ -300,9 +309,9 @@ class ModelPatcher:
         self.backup = {}
 
-        if target_device is not None:
-            self.model.to(target_device)
-            self.current_device = target_device
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
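
For context on the pattern being restored: the reverted "simplify codes" commit made the target device a mandatory torch.device argument and moved tensors unconditionally, while the code brought back here treats the device as optional and only moves weights when one is given. Below is a minimal, self-contained sketch of that optional-device cast, assuming only PyTorch; supports_non_blocking is a hypothetical stand-in for Forge's memory_management.device_supports_non_blocking, and cast_weight is an illustrative helper, not a function from the codebase.

import torch

def supports_non_blocking(device):
    # Assumption: async (pinned-memory) copies only pay off on CUDA;
    # the real Forge check also special-cases other backends.
    return device.type == 'cuda'

def cast_weight(weight, device_to=None):
    # Always upcast to float32 for patch math, but only move the
    # tensor when a target device was given, so device_to=None
    # leaves the weight on its current device.
    to_args = dict(dtype=torch.float32)
    if device_to is not None:
        to_args['device'] = device_to
        to_args['non_blocking'] = supports_non_blocking(device_to)
    return weight.to(**to_args)

w = torch.randn(4, 4, dtype=torch.float16)
print(cast_weight(w).dtype)  # torch.float32, device unchanged
if torch.cuda.is_available():
    print(cast_weight(w, torch.device('cuda')).device)  # cuda:0

The same None-check drives the BNB branch in the diff: a quantized bitsandbytes weight must land on CUDA before functional_dequantize_4bit can run, so the restored code asserts device_to.type == 'cuda' when a device is supplied and falls back to weight.cuda() when it is not.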