diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index 8880b318..642109f4 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -127,6 +127,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "lokr":
             w1 = v[0]
             w2 = v[1]
@@ -173,6 +174,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "loha":
             w1a = v[0]
             w1b = v[1]
@@ -210,6 +212,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "glora":
             if v[4] is not None:
                 alpha = v[4] / v[0].shape[0]
@@ -231,6 +234,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type in extra_weight_calculators:
             weight = extra_weight_calculators[patch_type](weight, strength, v)
         else:
@@ -336,7 +340,12 @@ class LoraLoader:
                 weight = weight.data
 
             if target_device is not None:
-                weight = weight.to(device=target_device)
+                try:
+                    weight = weight.to(device=target_device)
+                except Exception:  # e.g. CUDA out of memory on the target device
+                    print('Moving layer weight failed. Retrying after offloading the model.')
+                    self.model.to(device=offload_device)
+                    weight = weight.to(device=target_device)
 
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
 
@@ -348,8 +357,15 @@ class LoraLoader:
                 weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32)
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+
+            try:
+                weight = weight.to(dtype=torch.float32)
+                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            except Exception:  # e.g. CUDA out of memory while merging in float32
+                print('Patching LoRA weights failed. Retrying after offloading the model.')
+                self.model.to(device=offload_device)
+                weight = weight.to(dtype=torch.float32)
+                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
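
Note on the change as a whole: the two sets of edits work together as a retry-after-offload fallback. merge_lora_to_model_weight() now re-raises the exception it logs instead of silently continuing with an unpatched weight, so the caller in LoraLoader can catch the failure (typically the device running out of memory), move the whole model to offload_device to free memory, and attempt the transfer / merge one more time. Below is a minimal standalone sketch of that pattern, assuming a torch module and tensor; the helper name retry_with_offload and its callable argument are illustrative only and not part of this patch.

import torch


def retry_with_offload(operation, model, offload_device=torch.device('cpu')):
    # Run `operation()` once; if it raises (typically an out-of-memory error),
    # offload the whole model to free device memory and retry a single time.
    try:
        return operation()
    except Exception as e:
        print('Operation failed ({}). Retrying after offloading the model.'.format(e))
        model.to(device=offload_device)
        return operation()


# Example mirroring the device-move fallback in the diff (names follow the diff):
# weight = retry_with_offload(lambda: weight.to(device=target_device), self.model)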