simplify codes

This commit is contained in:
layerdiffusion
2024-08-14 20:48:39 -07:00
parent 4b66cf1126
commit 59790f2cb4
2 changed files with 16 additions and 16 deletions

View File

@@ -375,7 +375,7 @@ class LoadedModel:
self.model.model_patches_to(self.model.model_dtype())
try:
self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
self.real_model = self.model.forge_patch_model(patch_model_to)
except Exception as e:
self.model.forge_unpatch_model(self.model.offload_device)
self.model_unload()

View File

@@ -236,7 +236,7 @@ class ModelPatcher:
sd.pop(k)
return sd
def forge_patch_model(self, device_to=None):
def forge_patch_model(self, target_device=None):
for k, item in self.object_patches.items():
old = utils.get_attr(self.model, k)
@@ -264,16 +264,16 @@ class ModelPatcher:
if weight.bnb_quantized:
weight_original_device = weight.device
if device_to is not None:
assert device_to.type == 'cuda', 'BNB Must use CUDA!'
weight = weight.to(device_to)
if target_device is not None:
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
weight = weight.to(target_device)
else:
weight = weight.cuda()
from backend.operations_bnb import functional_dequantize_4bit
weight = functional_dequantize_4bit(weight)
if device_to is None:
if target_device is None:
weight = weight.to(device=weight_original_device)
else:
weight = weight.data
@@ -281,9 +281,9 @@ class ModelPatcher:
weight_original_dtype = weight.dtype
to_args = dict(dtype=torch.float32)
if device_to is not None:
to_args['device'] = device_to
to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
if target_device is not None:
to_args['device'] = target_device
to_args['non_blocking'] = memory_management.device_supports_non_blocking(target_device)
weight = weight.to(**to_args)
out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
@@ -294,13 +294,13 @@ class ModelPatcher:
utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
if target_device is not None:
self.model.to(target_device)
self.current_device = target_device
return self.model
def forge_unpatch_model(self, device_to=None):
def forge_unpatch_model(self, target_device=None):
keys = list(self.backup.keys())
for k in keys:
@@ -314,9 +314,9 @@ class ModelPatcher:
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
if target_device is not None:
self.model.to(target_device)
self.current_device = target_device
keys = list(self.object_patches_backup.keys())
for k in keys: