mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Wan22 14b training is working, still need tons of testing and some bug fixes
This commit is contained in:
@@ -310,6 +310,7 @@ class Wan21(BaseModel):
|
||||
arch = 'wan21'
|
||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||
_wan_expand_timesteps = False
|
||||
_wan_vae_path = None
|
||||
|
||||
_comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
|
||||
def __init__(
|
||||
@@ -431,8 +432,14 @@ class Wan21(BaseModel):
|
||||
scheduler = Wan21.get_train_scheduler()
|
||||
self.print_and_status_update("Loading VAE")
|
||||
# todo, example does float 32? check if quality suffers
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
|
||||
|
||||
if self._wan_vae_path is not None:
|
||||
# load the vae from individual repo
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype)
|
||||
else:
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
|
||||
flush()
|
||||
|
||||
self.print_and_status_update("Making pipe")
|
||||
|
||||
Reference in New Issue
Block a user