fix possible oom

commit bbd0d76b28
parent cb889470ba
Author: layerdiffusion
Date:   2024-08-14 20:27:05 -07:00


@@ -273,15 +273,15 @@ class ModelPatcher:
     else:
         weight = weight.data
     weight_original_dtype = weight.dtype
     to_args = dict(dtype=torch.float32)
     if device_to is not None:
         to_args['device'] = device_to
         to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
-    temp_weight = weight.to(**to_args)
-    out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
+    weight = weight.to(**to_args)
+    out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
     if bnb_layer is not None:
         bnb_layer.reload_weight(out_weight)
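
Why this avoids an OOM: in the previous version, `temp_weight` held the float32 copy while the local `weight` (which at this point may itself be a freshly dequantized or detached copy) stayed referenced through the whole merge, so two large copies of the tensor coexisted on the device. Rebinding `weight` drops that earlier reference as soon as the upcast finishes, and because `weight.dtype` then reports float32, the original dtype has to be captured up front in `weight_original_dtype` for the final cast. Below is a minimal sketch of the same pattern, with hypothetical `patch_weight` and `merge_patches` helpers standing in for the ModelPatcher method and `merge_lora_to_model_weight`:

import torch


def merge_patches(patches, w, key):
    # Hypothetical stand-in for merge_lora_to_model_weight: accumulate patch
    # deltas onto the float32 weight. The real merge logic lives in the repo.
    for delta in patches:
        w = w + delta.to(dtype=w.dtype, device=w.device)
    return w


def patch_weight(weight, patches, key, device_to=None):
    # Capture the original dtype before the upcast; after the rebind below,
    # weight.dtype would already report float32.
    weight_original_dtype = weight.dtype

    to_args = dict(dtype=torch.float32)
    if device_to is not None:
        to_args['device'] = device_to

    # Rebind `weight` rather than introducing a separate `temp_weight`: the old
    # reference is dropped once the upcast returns, so only one large copy of
    # the tensor stays alive while the patches are merged.
    weight = weight.to(**to_args)
    out_weight = merge_patches(patches, weight, key).to(dtype=weight_original_dtype)
    return out_weight


# Example: a float16 weight plus one LoRA-style delta; the result keeps float16.
w = torch.randn(4, 4, dtype=torch.float16)
print(patch_weight(w, [torch.full((4, 4), 0.01)], key="demo").dtype)  # torch.float16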