Speed up LoRA patching

Commit: f510f51303
Parent: 141cf81c23
Author: layerdiffusion
Date: 2024-08-15 06:51:52 -07:00


@@ -282,6 +282,9 @@ class ModelPatcher:
             else:
                 weight = weight.data
 
+            if target_device is not None:
+                weight = weight.to(device=target_device)
+
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
 
             if hasattr(weight, 'is_gguf'):
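
The first hunk moves the device transfer ahead of the GGUF branch, so a quantized weight is copied to the target device in its compact form and dequantized there. One plausible reason this is faster (not stated in the commit itself): the quantized payload is several times smaller than its float32 expansion, so less data crosses the host-to-device bus. A rough, self-contained illustration with made-up sizes:

```python
import torch

# Illustration only: an 8-bit buffer stands in for a GGUF-quantized weight
# (assumption); real GGUF payloads are packed differently.
quantized = torch.zeros(4096, 4096, dtype=torch.uint8)   # ~16 MB payload

if torch.cuda.is_available():
    # New order: copy the small payload first...
    on_device = quantized.to(device='cuda')              # ~16 MB over the bus
    # ...then do the dequantize-like expansion on the target device.
    expanded = on_device.to(dtype=torch.float32)         # ~64 MB, but no host copy
```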
@@ -292,12 +295,7 @@ class ModelPatcher:
                 weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
 
-            if target_device is not None:
-                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
-
             weight = weight.to(dtype=torch.float32)
             weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
             if bnb_layer is not None:
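
The second hunk drops the now-redundant later transfer, including its non_blocking path through memory_management.device_supports_non_blocking, leaving a single up-front copy. A minimal sketch of the resulting weight-patching order, assuming dequantize_tensor and merge_lora_to_model_weight are injected stand-ins for the repository's own helpers (their signatures here are assumptions):

```python
import torch

def patch_weight(weight, current_patches, key, target_device,
                 dequantize_tensor, merge_lora_to_model_weight):
    # Sketch of the post-commit order; the helper callables are stand-ins.
    # One transfer happens up front, on the raw (possibly quantized) weight.
    if target_device is not None:
        weight = weight.to(device=target_device)

    # GGUF weights are dequantized after the move, i.e. on the target device.
    if hasattr(weight, 'is_gguf'):
        weight = dequantize_tensor(weight)

    weight_original_dtype = weight.dtype

    # LoRA deltas are merged in float32, then the original dtype is restored.
    weight = weight.to(dtype=torch.float32)
    weight = merge_lora_to_model_weight(current_patches, weight, key)
    return weight.to(dtype=weight_original_dtype)
```

Note that merging in float32 and restoring the original dtype afterwards is unchanged by the commit; only the placement of the device transfer moves.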