mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-08 17:09:59 +00:00
@@ -236,9 +236,7 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def forge_patch_model(self, target_device):
|
||||
assert isinstance(target_device, torch.device)
|
||||
|
||||
def forge_patch_model(self, device_to=None):
|
||||
for k, item in self.object_patches.items():
|
||||
old = utils.get_attr(self.model, k)
|
||||
|
||||
@@ -264,15 +262,25 @@ class ModelPatcher:
|
||||
assert weight.module is not None, 'BNB bad weight without parent layer!'
|
||||
bnb_layer = weight.module
|
||||
if weight.bnb_quantized:
|
||||
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
|
||||
weight = weight.to(target_device)
|
||||
if device_to is not None:
|
||||
assert device_to.type == 'cuda', 'BNB Must use CUDA!'
|
||||
weight = weight.to(device_to)
|
||||
else:
|
||||
weight = weight.cuda()
|
||||
|
||||
from backend.operations_bnb import functional_dequantize_4bit
|
||||
weight = functional_dequantize_4bit(weight)
|
||||
else:
|
||||
weight = weight.data
|
||||
|
||||
weight_original_dtype = weight.dtype
|
||||
weight = weight.to(dtype=torch.float32, device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
|
||||
to_args = dict(dtype=torch.float32)
|
||||
|
||||
if device_to is not None:
|
||||
to_args['device'] = device_to
|
||||
to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
|
||||
|
||||
weight = weight.to(**to_args)
|
||||
out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
|
||||
|
||||
if bnb_layer is not None:
|
||||
@@ -281,12 +289,13 @@ class ModelPatcher:
|
||||
|
||||
utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
|
||||
|
||||
self.model.to(target_device)
|
||||
self.current_device = target_device
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
def forge_unpatch_model(self, target_device=None):
|
||||
def forge_unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
for k in keys:
|
||||
@@ -300,9 +309,9 @@ class ModelPatcher:
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if target_device is not None:
|
||||
self.model.to(target_device)
|
||||
self.current_device = target_device
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
|
||||
Reference in New Issue
Block a user