Wan 2.2 14B training is working; still needs extensive testing and some bug fixes

Jaret Burkett
2025-08-14 13:03:27 -06:00
parent be71cc75ce
commit 3413fa537f
8 changed files with 554 additions and 24 deletions


@@ -185,6 +185,9 @@ class NetworkConfig:
         self.conv_alpha = 9999999999
         # -1 automatically finds the largest factor
         self.lokr_factor = kwargs.get('lokr_factor', -1)
+        # for multi-stage models
+        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
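For orientation, here is a minimal, self-contained sketch of the kwargs pattern this option follows. The meaning of split_multistage_loras (save one LoRA file per model stage rather than a single combined file) is an assumption inferred from its name and the multi-stage comment, and NetworkConfigSketch is an illustrative stand-in, not the real class.

class NetworkConfigSketch:
    def __init__(self, **kwargs):
        # unknown options fall back to defaults, mirroring the kwargs.get pattern
        self.lokr_factor = kwargs.get('lokr_factor', -1)
        # assumption: True means one LoRA file per stage for multi-stage models
        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)

cfg = NetworkConfigSketch(split_multistage_loras=False)
print(cfg.split_multistage_loras)  # False -> save a single combined file instead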


@@ -310,6 +310,7 @@ class Wan21(BaseModel):
     arch = 'wan21'
     _wan_generation_scheduler_config = scheduler_configUniPC
     _wan_expand_timesteps = False
+    _wan_vae_path = None
     _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
 
     def __init__(
@@ -431,8 +432,14 @@ class Wan21(BaseModel):
         scheduler = Wan21.get_train_scheduler()
         self.print_and_status_update("Loading VAE")
         # todo: the reference example uses float32; check if quality suffers
-        vae = AutoencoderKLWan.from_pretrained(
-            vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
+        if self._wan_vae_path is not None:
+            # load the VAE from its own standalone repo
+            vae = AutoencoderKLWan.from_pretrained(
+                self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
         flush()
         self.print_and_status_update("Making pipe")
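Pulled out of the diff, the branching above reduces to a small helper. This is a sketch assuming standard Hugging Face layouts: a full pipeline repo keeps the VAE under a vae/ subfolder, while a dedicated VAE repo keeps the weights at the root. The helper name is illustrative, not from the repo.

import torch
from diffusers import AutoencoderKLWan

def load_wan_vae(vae_path, wan_vae_path=None, dtype=torch.float32):
    if wan_vae_path is not None:
        # dedicated VAE repo: weights sit at the repo root, no subfolder needed
        return AutoencoderKLWan.from_pretrained(wan_vae_path, torch_dtype=dtype)
    # full pipeline repo: the VAE lives under the "vae" subfolder
    return AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=dtype)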


@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
         if metadata is None:
             metadata = OrderedDict()
         metadata = add_model_hash_to_meta(save_dict, metadata)
+
+        # let the model handle the saving if it implements save_lora
+        if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
+            # call the base model's save_lora method
+            self.base_model_ref().save_lora(save_dict, file, metadata)
+            return
+
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
             save_file(save_dict, file, metadata)
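The base_model_ref() call syntax suggests the mixin stores a weakref to the model; under that assumption, here is a runnable sketch of the delegation guard in isolation. DemoMultiStageModel and its trivial save_lora body are hypothetical.

import weakref

class DemoMultiStageModel:
    # hypothetical model that wants custom LoRA saving, e.g. one file per stage
    def save_lora(self, save_dict, file_path, metadata):
        print(f"model-specific save of {len(save_dict)} tensors to {file_path}")

model = DemoMultiStageModel()
base_model_ref = weakref.ref(model)  # assumed: the mixin holds a weak reference

# same guard as the hunk above: delegate only when the model opts in
if base_model_ref is not None and hasattr(base_model_ref(), 'save_lora'):
    base_model_ref().save_lora({}, 'my_lora.safetensors', {})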
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
         keymap = {} if keymap is None else keymap
         if isinstance(file, str):
-            if os.path.splitext(file)[1] == ".safetensors":
-                from safetensors.torch import load_file
-                weights_sd = load_file(file)
+            if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
+                # call the base model's load_lora method
+                weights_sd = self.base_model_ref().load_lora(file)
             else:
-                weights_sd = torch.load(file, map_location="cpu")
+                if os.path.splitext(file)[1] == ".safetensors":
+                    from safetensors.torch import load_file
+                    weights_sd = load_file(file)
+                else:
+                    weights_sd = torch.load(file, map_location="cpu")
         else:
             # probably a state dict
             weights_sd = file
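The new branch implies a contract: load_lora(file) must return a plain state dict, the same kind of mapping load_file or torch.load would produce. That leaves room for a multi-stage model to reassemble split files, which is presumably what pairs with split_multistage_loras. A purely hypothetical sketch (the stage names and file layout are assumptions, not the repo's actual scheme):

import os
from safetensors.torch import load_file

def load_lora(file):
    # hypothetical: if per-stage files were saved next to the requested path,
    # merge them back into one state dict; otherwise load the file directly
    base, ext = os.path.splitext(file)
    merged = {}
    for stage in ("high_noise", "low_noise"):  # assumed stage names
        stage_file = f"{base}_{stage}{ext}"
        if os.path.exists(stage_file):
            merged.update(load_file(stage_file))
    return merged if merged else load_file(file)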