mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Wan22 14b training is working, still need tons of testing and some bug fixes
This commit is contained in:
@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
|
||||
if metadata is None:
|
||||
metadata = OrderedDict()
|
||||
metadata = add_model_hash_to_meta(save_dict, metadata)
|
||||
# let the model handle the saving
|
||||
|
||||
if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
|
||||
# call the base model save lora method
|
||||
self.base_model_ref().save_lora(save_dict, file, metadata)
|
||||
return
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
save_file(save_dict, file, metadata)
|
||||
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
|
||||
keymap = {} if keymap is None else keymap
|
||||
|
||||
if isinstance(file, str):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
|
||||
# call the base model load lora method
|
||||
weights_sd = self.base_model_ref().load_lora(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
else:
|
||||
# probably a state dict
|
||||
weights_sd = file
|
||||
|
||||
Reference in New Issue
Block a user