mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 16:30:40 +00:00
Hotfix some issues with Wan models caused by diffusers and transformers updates
This commit is contained in:
@@ -375,7 +375,9 @@ class Wan2214bModel(Wan21):
|
||||
return transformer
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
# todo unipc got broken in a diffusers update. Use euler for now.
|
||||
# scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
scheduler = self.get_train_scheduler()
|
||||
pipeline = Wan22Pipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.model.transformer_1,
|
||||
|
||||
@@ -116,7 +116,9 @@ class Wan225bModel(Wan21):
|
||||
return 32
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
# todo unipc got broken in a diffusers update. Use euler for now.
|
||||
# scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
scheduler = self.get_train_scheduler()
|
||||
pipeline = Wan22Pipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.model,
|
||||
|
||||
@@ -1,8 +1,30 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from transformers import T5Tokenizer, UMT5EncoderModel
|
||||
from toolkit.models.loaders.comfy import get_comfy_path
|
||||
|
||||
class PatchedT5Tokenizer(T5Tokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
vocab: str | list[tuple[str, float]] | None = None,
|
||||
eos_token="</s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
_spm_precompiled_charsmap=None,
|
||||
extra_ids=100,
|
||||
additional_special_tokens=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab=vocab,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
_spm_precompiled_charsmap=None, # this is passing a empty byte string for some reason now.
|
||||
extra_ids=extra_ids,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_umt5_encoder(
|
||||
model_path: str,
|
||||
@@ -17,7 +39,7 @@ def get_umt5_encoder(
|
||||
"""
|
||||
Load the UMT5 encoder model from the specified path.
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
|
||||
tokenizer = PatchedT5Tokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
|
||||
comfy_path = get_comfy_path(comfy_files)
|
||||
comfy_path = None
|
||||
if comfy_path is not None:
|
||||
|
||||
@@ -491,7 +491,9 @@ class Wan21(BaseModel):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
# todo unipc got broken in a diffusers update. Use euler for now.
|
||||
# scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
scheduler = self.get_train_scheduler()
|
||||
if self.model_config.low_vram:
|
||||
pipeline = AggressiveWanUnloadPipeline(
|
||||
vae=self.vae,
|
||||
|
||||
@@ -369,7 +369,9 @@ class Wan21I2V(Wan21):
|
||||
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
|
||||
# todo unipc got broken in a diffusers update. Use euler for now.
|
||||
# scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
scheduler = self.get_train_scheduler()
|
||||
if self.model_config.low_vram:
|
||||
pipeline = AggressiveWanI2VUnloadPipeline(
|
||||
vae=self.vae,
|
||||
|
||||
Reference in New Issue
Block a user