completely solve all LoRA OOMs

layerdiffusion
2024-08-17 22:43:20 -07:00
parent 93bfd7f85b
commit db5a876d4c

@@ -127,6 +127,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "lokr":
             w1 = v[0]
             w2 = v[1]
@@ -173,6 +174,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "loha":
             w1a = v[0]
             w1b = v[1]
@@ -210,6 +212,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "glora":
             if v[4] is not None:
                 alpha = v[4] / v[0].shape[0]
@@ -231,6 +234,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type in extra_weight_calculators:
             weight = extra_weight_calculators[patch_type](weight, strength, v)
         else:
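
Note: the four `raise e` additions above share one purpose. Each patch branch (lora, lokr, loha, glora) still prints its diagnostic, but now re-raises the exception instead of silently leaving the weight unpatched, so the caller in LoraLoader can catch the failure and retry after freeing memory. A minimal sketch of that log-and-re-raise pattern, assuming a simplified stand-in function `merge_patch` (its argument list and identity `function` default are illustrative, not the project's actual signature):

import torch

def merge_patch(weight, lora_diff, strength, alpha, patch_type, key, function=lambda x: x):
    # Log-and-re-raise: keep the diagnostic print, but let the exception
    # propagate so the caller can offload models and retry the merge.
    try:
        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
    except Exception as e:
        print("ERROR {} {} {}".format(patch_type, key, e))
        raise e
    return weight

# Example: on success the patched weight is returned; on failure the error
# is both printed and propagated instead of being swallowed.
w = merge_patch(torch.zeros(4, 4), torch.ones(4, 4), strength=1.0, alpha=0.5, patch_type="lora", key="demo")
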
@@ -336,7 +340,12 @@ class LoraLoader:
             weight = weight.data
             if target_device is not None:
-                weight = weight.to(device=target_device)
+                try:
+                    weight = weight.to(device=target_device)
+                except:
+                    print('Moving layer weight failed. Retrying by offload models.')
+                    self.model.to(device=offload_device)
+                    weight = weight.to(device=target_device)
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
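
This first retry site covers the device move: if copying a single layer weight to the target device fails (typically a CUDA out-of-memory error), the whole model is pushed to the offload device to free VRAM and the same move is attempted once more. A standalone sketch of that fallback, assuming a hypothetical `move_with_fallback` helper and a CPU offload device (in the commit itself, `self.model` and `offload_device` are LoraLoader state):

import torch

def move_with_fallback(weight, model, target_device, offload_device=torch.device('cpu')):
    # First attempt: move only this layer's weight.
    try:
        return weight.to(device=target_device)
    except Exception:
        # Likely OOM on the target device: offload the whole model,
        # then retry the same move with the freed memory.
        print('Moving layer weight failed. Retrying by offload models.')
        model.to(device=offload_device)
        return weight.to(device=target_device)

# Example usage (CPU-only, so the first attempt simply succeeds):
layer = torch.nn.Linear(8, 8)
w = move_with_fallback(layer.weight.data, layer, target_device=torch.device('cpu'))
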
@@ -348,8 +357,15 @@ class LoraLoader:
                 weight = dequantize_tensor(weight)
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32)
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            try:
+                weight = weight.to(dtype=torch.float32)
+                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            except:
+                print('Patching LoRA weights failed. Retrying by offload models.')
+                self.model.to(device=offload_device)
+                weight = weight.to(dtype=torch.float32)
+                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
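
The second retry site applies the same offload-and-retry pattern to the expensive step: upcast to float32, merge the LoRA patches, then cast back to the original dtype. A sketch under the same assumptions (hypothetical `patch_with_fallback` helper; `merge_fn` stands in for `merge_lora_to_model_weight(current_patches, weight, key)`):

import torch

def patch_with_fallback(weight, merge_fn, model, offload_device=torch.device('cpu')):
    original_dtype = weight.dtype
    try:
        # The merge is done in float32 for precision, then cast back below.
        merged = merge_fn(weight.to(dtype=torch.float32))
    except Exception:
        # On failure (typically OOM during the merge), offload the model
        # and repeat the exact same computation.
        print('Patching LoRA weights failed. Retrying by offload models.')
        model.to(device=offload_device)
        merged = merge_fn(weight.to(dtype=torch.float32))
    return merged.to(dtype=original_dtype)

# Example usage with a trivial merge function:
layer = torch.nn.Linear(8, 8)
patched = patch_with_fallback(layer.weight.data.half(), lambda w: w + 1.0, layer)

Retrying the full computation rather than resuming midway keeps the fallback simple: the bare `except` treats any failure as recoverable and trades a redundant upcast for not having to distinguish OOM from other errors.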