Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Wan2.2 14B training is working; still needs tons of testing and some bug fixes
@@ -185,6 +185,9 @@ class NetworkConfig:
         self.conv_alpha = 9999999999
         # -1 automatically finds the largest factor
         self.lokr_factor = kwargs.get('lokr_factor', -1)
+
+        # for multi stage models
+        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
 
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
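This hunk only adds the `split_multistage_loras` flag; how it is consumed is not shown in this commit. A minimal sketch of what per-stage splitting of a combined save dict might look like, assuming hypothetical `transformer.`/`transformer_2.` key prefixes for the two Wan2.2 stages:

def split_state_dict_by_stage(save_dict):
    # Hypothetical helper, not from this commit: route LoRA keys into
    # per-stage dicts so each stage can be saved as its own file.
    stages = {'transformer': {}, 'transformer_2': {}}
    for key, value in save_dict.items():
        prefix = 'transformer_2' if key.startswith('transformer_2.') else 'transformer'
        stages[prefix][key] = value
    return stages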
@@ -310,6 +310,7 @@ class Wan21(BaseModel):
     arch = 'wan21'
     _wan_generation_scheduler_config = scheduler_configUniPC
     _wan_expand_timesteps = False
+    _wan_vae_path = None
 
     _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
     def __init__(
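`_wan_vae_path` defaults to None on Wan21 and acts as an override hook for arch subclasses whose VAE ships in its own repo rather than inside the pipeline repo. A minimal sketch of such a subclass; the class name and repo id below are placeholders, not taken from this diff:

class Wan22(Wan21):
    arch = 'wan22'
    # Placeholder repo id; the real Wan2.2 VAE location is not shown here.
    _wan_vae_path = 'some-org/wan2.2-vae'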
@@ -431,8 +432,14 @@ class Wan21(BaseModel):
         scheduler = Wan21.get_train_scheduler()
         self.print_and_status_update("Loading VAE")
         # todo, example does float 32? check if quality suffers
-        vae = AutoencoderKLWan.from_pretrained(
-            vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
+
+        if self._wan_vae_path is not None:
+            # load the vae from individual repo
+            vae = AutoencoderKLWan.from_pretrained(
+                self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
         flush()
 
         self.print_and_status_update("Making pipe")
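For context, diffusers' `AutoencoderKLWan.from_pretrained` loads either from a pipeline repo via `subfolder` or from a standalone VAE repo id, which is exactly the distinction the new branch makes. A hedged sketch of the pipeline-repo form; the repo id is an assumption based on the published Wan2.1 diffusers layout, not on this diff:

import torch
from diffusers import AutoencoderKLWan

# Pipeline-repo form: the VAE lives under a 'vae' subfolder.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed repo id
    subfolder="vae",
    torch_dtype=torch.bfloat16,
)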
@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
         if metadata is None:
             metadata = OrderedDict()
         metadata = add_model_hash_to_meta(save_dict, metadata)
+        # let the model handle the saving
+
+        if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
+            # call the base model save lora method
+            self.base_model_ref().save_lora(save_dict, file, metadata)
+            return
+
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
             save_file(save_dict, file, metadata)
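The delegation relies on `base_model_ref` being a weak reference to the model, hence the call parentheses to dereference it before the `hasattr` check. A minimal sketch of the pattern, assuming the surrounding toolkit wiring:

import weakref

class ExampleNetwork:
    # Sketch only: shows the weakref + hasattr delegation used above,
    # not the toolkit's full network class.
    def __init__(self, base_model):
        self.base_model_ref = weakref.ref(base_model)

    def save_weights(self, save_dict, file, metadata):
        model = self.base_model_ref() if self.base_model_ref is not None else None
        if model is not None and hasattr(model, 'save_lora'):
            model.save_lora(save_dict, file, metadata)  # model-specific format
            return
        from safetensors.torch import save_file
        save_file(save_dict, file, metadata)  # generic fallback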
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
         keymap = {} if keymap is None else keymap
 
         if isinstance(file, str):
-            if os.path.splitext(file)[1] == ".safetensors":
-                from safetensors.torch import load_file
-
-                weights_sd = load_file(file)
+            if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
+                # call the base model load lora method
+                weights_sd = self.base_model_ref().load_lora(file)
             else:
-                weights_sd = torch.load(file, map_location="cpu")
+                if os.path.splitext(file)[1] == ".safetensors":
+                    from safetensors.torch import load_file
+                    weights_sd = load_file(file)
+                else:
+                    weights_sd = torch.load(file, map_location="cpu")
         else:
             # probably a state dict
             weights_sd = file
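Together the two hooks make LoRA serialization round-trip through the base model: whatever key layout `save_lora` writes, `load_lora` must map back before the network applies the weights. A hedged sketch of a model implementing the pair; the remap helpers are hypothetical, not toolkit API:

class MultiStageModel:
    # Hypothetical model-side pair: the network mixin above calls these
    # when they exist. The remap_* helpers are placeholders.
    def save_lora(self, save_dict, file, metadata):
        from safetensors.torch import save_file
        save_file(self.remap_keys_for_export(save_dict), file, metadata)

    def load_lora(self, file):
        from safetensors.torch import load_file
        return self.remap_keys_for_import(load_file(file))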