diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index a32183ce..aef30b7d 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -375,7 +375,9 @@ class Wan2214bModel(Wan21): return transformer def get_generation_pipeline(self): - scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + # todo unipc got broken in a diffusers update. Use euler for now. + # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() pipeline = Wan22Pipeline( vae=self.vae, transformer=self.model.transformer_1, diff --git a/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py index 9c446cc5..9d29129b 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py @@ -116,7 +116,9 @@ class Wan225bModel(Wan21): return 32 def get_generation_pipeline(self): - scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + # todo unipc got broken in a diffusers update. Use euler for now. + # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() pipeline = Wan22Pipeline( vae=self.vae, transformer=self.model, diff --git a/toolkit/models/loaders/umt5.py b/toolkit/models/loaders/umt5.py index fd666269..b3613923 100644 --- a/toolkit/models/loaders/umt5.py +++ b/toolkit/models/loaders/umt5.py @@ -1,8 +1,30 @@ from typing import List import torch -from transformers import AutoTokenizer, UMT5EncoderModel +from transformers import T5Tokenizer, UMT5EncoderModel from toolkit.models.loaders.comfy import get_comfy_path +class PatchedT5Tokenizer(T5Tokenizer): + def __init__( + self, + vocab: str | list[tuple[str, float]] | None = None, + eos_token="", + unk_token="", + pad_token="", + _spm_precompiled_charsmap=None, + extra_ids=100, + additional_special_tokens=None, + **kwargs, + ): + super().__init__( + vocab=vocab, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + _spm_precompiled_charsmap=None, # this is passing a empty byte string for some reason now. + extra_ids=extra_ids, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) def get_umt5_encoder( model_path: str, @@ -17,7 +39,7 @@ def get_umt5_encoder( """ Load the UMT5 encoder model from the specified path. """ - tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder) + tokenizer = PatchedT5Tokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder) comfy_path = get_comfy_path(comfy_files) comfy_path = None if comfy_path is not None: diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 4932b51d..7e14bdc4 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -491,7 +491,9 @@ class Wan21(BaseModel): self.tokenizer = tokenizer def get_generation_pipeline(self): - scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + # todo unipc got broken in a diffusers update. Use euler for now. + # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() if self.model_config.low_vram: pipeline = AggressiveWanUnloadPipeline( vae=self.vae, diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py index bf5a88b8..fbe012f8 100644 --- a/toolkit/models/wan21/wan21_i2v.py +++ b/toolkit/models/wan21/wan21_i2v.py @@ -369,7 +369,9 @@ class Wan21I2V(Wan21): def get_generation_pipeline(self): - scheduler = UniPCMultistepScheduler(**scheduler_configUniPC) + # todo unipc got broken in a diffusers update. Use euler for now. + # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() if self.model_config.low_vram: pipeline = AggressiveWanI2VUnloadPipeline( vae=self.vae,