Wan22 14b training is working, still need tons of testing and some bug fixes

This commit is contained in:
Jaret Burkett
2025-08-14 13:03:27 -06:00
parent be71cc75ce
commit 3413fa537f
8 changed files with 554 additions and 24 deletions

View File

@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(save_dict, metadata)
# let the model handle the saving
if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
# call the base model save lora method
self.base_model_ref().save_lora(save_dict, file, metadata)
return
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(save_dict, file, metadata)
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
keymap = {} if keymap is None else keymap
if isinstance(file, str):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
# call the base model load lora method
weights_sd = self.base_model_ref().load_lora(file)
else:
weights_sd = torch.load(file, map_location="cpu")
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
else:
# probably a state dict
weights_sd = file