Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git

Commit: simplify codes
@@ -375,7 +375,7 @@ class LoadedModel:
         self.model.model_patches_to(self.model.model_dtype())
 
         try:
-            self.real_model = self.model.forge_patch_model(device_to=patch_model_to)
+            self.real_model = self.model.forge_patch_model(patch_model_to)
         except Exception as e:
             self.model.forge_unpatch_model(self.model.offload_device)
             self.model_unload()
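The keyword in the old call no longer exists after this commit: forge_patch_model's parameter is renamed from device_to to target_device (see the next hunk), so the call site switches to a positional argument. A minimal sketch of why, using a stand-in function rather than the forge code itself:

# stand-in for the renamed method; only the parameter name matters here
def forge_patch_model(target_device=None):
    return target_device

forge_patch_model('cuda:0')               # positional call is unaffected by the rename
# forge_patch_model(device_to='cuda:0')   # the old keyword would now raise TypeError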
@@ -236,7 +236,7 @@ class ModelPatcher:
                     sd.pop(k)
         return sd
 
-    def forge_patch_model(self, device_to=None):
+    def forge_patch_model(self, target_device=None):
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
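forge_patch_model begins by applying object patches: each key in object_patches is a dotted attribute path into the model, and utils.get_attr resolves it so the original object can be saved into object_patches_backup before being replaced. A sketch of that dotted-path lookup; the behavior is assumed from the call site, not copied from forge:

import functools

def get_attr(obj, name):
    # resolve a dotted path such as "a.b.c" by repeated getattr calls
    return functools.reduce(getattr, name.split('.'), obj)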
@@ -264,16 +264,16 @@ class ModelPatcher:
             if weight.bnb_quantized:
                 weight_original_device = weight.device
 
-                if device_to is not None:
-                    assert device_to.type == 'cuda', 'BNB Must use CUDA!'
-                    weight = weight.to(device_to)
+                if target_device is not None:
+                    assert target_device.type == 'cuda', 'BNB Must use CUDA!'
+                    weight = weight.to(target_device)
                 else:
                     weight = weight.cuda()
 
                 from backend.operations_bnb import functional_dequantize_4bit
                 weight = functional_dequantize_4bit(weight)
 
-                if device_to is None:
+                if target_device is None:
                     weight = weight.to(device=weight_original_device)
             else:
                 weight = weight.data
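For bitsandbytes 4-bit weights the merge cannot run on the quantized storage: the weight is first moved to a CUDA device (the only place the bnb kernels run), dequantized to a dense tensor, and, when no target device was requested, moved back to where it originally lived. A simplified sketch of that round trip, with a placeholder dequantize() standing in for backend.operations_bnb.functional_dequantize_4bit:

def dequantize(w):
    # placeholder: the real helper expands a bnb 4-bit weight into a dense tensor
    return w.float()

def materialize_bnb_weight(weight, target_device=None):
    original_device = weight.device
    if target_device is not None:
        assert target_device.type == 'cuda', 'BNB Must use CUDA!'
        weight = weight.to(target_device)
    else:
        weight = weight.cuda()                      # bnb kernels are CUDA-only
    weight = dequantize(weight)
    if target_device is None:
        weight = weight.to(device=original_device)  # restore the original placement
    return weight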
@@ -281,9 +281,9 @@ class ModelPatcher:
             weight_original_dtype = weight.dtype
             to_args = dict(dtype=torch.float32)
 
-            if device_to is not None:
-                to_args['device'] = device_to
-                to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
+            if target_device is not None:
+                to_args['device'] = target_device
+                to_args['non_blocking'] = memory_management.device_supports_non_blocking(target_device)
 
             weight = weight.to(**to_args)
             out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
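The LoRA merge itself runs in float32, and the device and non_blocking keys are added to the .to() arguments only when a target device was given, so a patch applied in place triggers no transfer. A sketch of that kwargs-building pattern, with a simple supports_non_blocking() assumption standing in for memory_management.device_supports_non_blocking:

import torch

def supports_non_blocking(device):
    # assumption: treat only CUDA targets as supporting asynchronous copies
    return device.type == 'cuda'

def cast_for_merge(weight, target_device=None):
    to_args = dict(dtype=torch.float32)             # merge math runs in fp32
    if target_device is not None:
        to_args['device'] = target_device
        to_args['non_blocking'] = supports_non_blocking(target_device)
    return weight.to(**to_args)

w = torch.randn(4, 4, dtype=torch.float16)
print(cast_for_merge(w).dtype)                      # torch.float32, still on the CPU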
@@ -294,13 +294,13 @@ class ModelPatcher:
 
             utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
 
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+        if target_device is not None:
+            self.model.to(target_device)
+            self.current_device = target_device
 
         return self.model
 
-    def forge_unpatch_model(self, device_to=None):
+    def forge_unpatch_model(self, target_device=None):
         keys = list(self.backup.keys())
 
         for k in keys:
@@ -314,9 +314,9 @@ class ModelPatcher:
 
         self.backup = {}
 
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+        if target_device is not None:
+            self.model.to(target_device)
+            self.current_device = target_device
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
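Both forge_patch_model and forge_unpatch_model end with the same tail: the model is moved only when a target device is given, and current_device is recorded so later bookkeeping knows where the weights live. A hedged usage sketch of the renamed API; TinyPatcher is a toy stand-in for ModelPatcher that shows only the call shape and this shared tail:

import torch

class TinyPatcher:
    def __init__(self, model, offload_device=torch.device('cpu')):
        self.model = model
        self.offload_device = offload_device
        self.current_device = offload_device

    def forge_patch_model(self, target_device=None):
        # the real method merges LoRA patches into the weights before this move
        if target_device is not None:
            self.model.to(target_device)
            self.current_device = target_device
        return self.model

    def forge_unpatch_model(self, target_device=None):
        # the real method restores backed-up weights before this move
        if target_device is not None:
            self.model.to(target_device)
            self.current_device = target_device

patcher = TinyPatcher(torch.nn.Linear(8, 8))
real_model = patcher.forge_patch_model()             # positional or omitted, never device_to=
patcher.forge_unpatch_model(patcher.offload_device)  # move back to the offload device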