Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
completely solve all LoRA OOMs
@@ -127,6 +127,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "lokr":
             w1 = v[0]
             w2 = v[1]
@@ -173,6 +174,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "loha":
             w1a = v[0]
             w1b = v[1]
@@ -210,6 +212,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type == "glora":
             if v[4] is not None:
                 alpha = v[4] / v[0].shape[0]
@@ -231,6 +234,7 @@ def merge_lora_to_model_weight(patches, weight, key):
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
+                raise e
         elif patch_type in extra_weight_calculators:
             weight = extra_weight_calculators[patch_type](weight, strength, v)
         else:
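
The four hunks above make the same one-line change in each branch of merge_lora_to_model_weight: after logging, the error is re-raised instead of being swallowed, so an out-of-memory failure during merging reaches the caller, which (in the LoraLoader changes below) can offload the model and retry. A minimal standalone sketch of that division of labour; merge_one_patch and merge_with_fallback are hypothetical names used only for illustration, not functions in this repository:

# Sketch only: hypothetical helpers showing the re-raise + retry split.
def merge_one_patch(weight, lora_diff, strength, alpha, patch_type, key):
    try:
        # Same accumulation step as in the diff: scale the LoRA delta and add it in.
        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
    except Exception as e:
        print("ERROR {} {} {}".format(patch_type, key, e))
        raise e  # new behaviour: let the caller decide how to recover
    return weight

def merge_with_fallback(weight, lora_diff, strength, alpha, key, model, offload_device):
    try:
        return merge_one_patch(weight, lora_diff, strength, alpha, "lora", key)
    except Exception:
        # Typically a CUDA OOM: free memory by offloading the model, then retry once.
        model.to(device=offload_device)
        return merge_one_patch(weight, lora_diff, strength, alpha, "lora", key)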
@@ -336,7 +340,12 @@ class LoraLoader:
                 weight = weight.data
 
             if target_device is not None:
-                weight = weight.to(device=target_device)
+                try:
+                    weight = weight.to(device=target_device)
+                except:
+                    print('Moving layer weight failed. Retrying by offload models.')
+                    self.model.to(device=offload_device)
+                    weight = weight.to(device=target_device)
 
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
 
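
The hunk above wraps the device move in a fallback: if weight.to(target_device) fails (usually a CUDA allocation failure), the whole model is first moved to the offload device to free memory and the move is retried once. A condensed sketch of that control flow as a standalone helper; move_weight_with_fallback is an illustrative name, not part of the project's API:

def move_weight_with_fallback(weight, model, target_device, offload_device):
    # Sketch of the move-with-fallback step added above (illustrative helper).
    if target_device is None:
        return weight
    try:
        return weight.to(device=target_device)
    except Exception:
        # Likely an allocation failure on the target device: evict the rest of
        # the model to the offload device, then retry the move once.
        print('Moving layer weight failed. Retrying by offload models.')
        model.to(device=offload_device)
        return weight.to(device=target_device)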
@@ -348,8 +357,15 @@ class LoraLoader:
                 weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32)
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+
+            try:
+                weight = weight.to(dtype=torch.float32)
+                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            except:
+                print('Patching LoRA weights failed. Retrying by offload models.')
+                self.model.to(device=offload_device)
+                weight = weight.to(dtype=torch.float32)
+                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
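
The same fallback is applied around the patching step itself: the weight is upcast to float32 (a temporary full-precision copy, which is where low-VRAM setups tend to run out of memory), merged, and cast back to its original dtype; on failure the model is offloaded and the sequence is retried. A sketch assuming a merge function with the same signature as merge_lora_to_model_weight(patches, weight, key); patch_weight_with_fallback is a hypothetical helper for illustration:

import torch

def patch_weight_with_fallback(merge_fn, current_patches, weight, key, model, offload_device):
    # Sketch of the patch-with-fallback step added above (illustrative helper).
    original_dtype = weight.dtype

    def _patch(w):
        w = w.to(dtype=torch.float32)  # upcast for accumulation precision
        return merge_fn(current_patches, w, key).to(dtype=original_dtype)

    try:
        return _patch(weight)
    except Exception:
        print('Patching LoRA weights failed. Retrying by offload models.')
        model.to(device=offload_device)  # free memory, then retry once
        return _patch(weight)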