From bbd0d76b28fa1bef020aa693236eb23014a50deb Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 14 Aug 2024 20:27:05 -0700
Subject: [PATCH] fix possible oom

---
 backend/patcher/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 1b40e984..dbeed0fa 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -273,15 +273,15 @@ class ModelPatcher:
         else:
             weight = weight.data
 
+        weight_original_dtype = weight.dtype
         to_args = dict(dtype=torch.float32)
 
         if device_to is not None:
             to_args['device'] = device_to
             to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
 
-        temp_weight = weight.to(**to_args)
-
-        out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
+        weight = weight.to(**to_args)
+        out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
         if bnb_layer is not None:
             bnb_layer.reload_weight(out_weight)
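
Why this can fix an OOM: in the removed code, the original weight tensor and its float32 copy temp_weight are both referenced while merge_lora_to_model_weight runs, so peak memory covers two full copies of the weight plus the merge's intermediates; the trailing .to(weight.dtype) also forces the original tensor to stay alive until after the merge just to read its dtype. The patch records the original dtype up front and rebinds weight to the float32 copy, dropping this frame's reference to the pre-cast tensor so it can be reclaimed (provided nothing else, such as a still-attached module parameter, holds it) before the merge allocates. Saving weight_original_dtype first is required because after the rebind weight.dtype would report float32. A minimal sketch of the two reference patterns in plain Python; merge(), patch_before(), and patch_after() are hypothetical stand-ins, not code from the repository:

    import torch

    def merge(w: torch.Tensor) -> torch.Tensor:
        # Stand-in for merge_lora_to_model_weight: returns a new tensor,
        # allocating merge intermediates along the way.
        return w * 2.0

    def patch_before(weight: torch.Tensor) -> torch.Tensor:
        # Removed pattern: 'weight' (e.g. fp16) and 'temp_weight' (fp32)
        # are both alive while merge() runs, so the peak holds two copies,
        # and reading weight.dtype afterwards pins the original tensor.
        temp_weight = weight.to(dtype=torch.float32)
        return merge(temp_weight).to(weight.dtype)

    def patch_after(weight: torch.Tensor) -> torch.Tensor:
        # Patched pattern: save the dtype, then rebind 'weight' to the
        # fp32 copy. The pre-cast tensor loses this frame's reference and
        # can be freed before merge() allocates, lowering peak memory.
        weight_original_dtype = weight.dtype
        weight = weight.to(dtype=torch.float32)
        return merge(weight).to(dtype=weight_original_dtype)

    w = torch.randn(256, 256).half()
    assert torch.equal(patch_before(w.clone()), patch_after(w.clone()))

Both paths produce the same result; only the lifetime of the pre-cast tensor differs.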