Switch to unified text encoder for wan models. Pred for 2.2 14b

This commit is contained in:
Jaret Burkett
2025-08-14 10:07:18 -06:00
parent e12bb21780
commit be71cc75ce
5 changed files with 101 additions and 62 deletions

View File

View File

@@ -0,0 +1,15 @@
import os
from typing import List
from toolkit.paths import COMFY_MODELS_PATH
def get_comfy_path(comfy_files: List[str]) -> str:
    """
    Return the first entry of *comfy_files* that exists under COMFY_MODELS_PATH.

    Each entry is treated as a path relative to COMFY_MODELS_PATH. Returns
    None when COMFY_MODELS_PATH is unset, the list is None/empty, or none of
    the candidate files exist on disk.
    """
    # Nothing to search: base path unset, or no candidates given.
    if COMFY_MODELS_PATH is None or not comfy_files:
        return None
    for candidate in comfy_files:
        full_path = os.path.join(COMFY_MODELS_PATH, candidate)
        if os.path.exists(full_path):
            return full_path
    return None

View File

@@ -0,0 +1,32 @@
from typing import List
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from toolkit.models.loaders.comfy import get_comfy_path
def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: "str | None" = None,
    encoder_subfolder: "str | None" = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: List[str] = (
        "text_encoders/umt5_xxl_fp16.safetensors",
        "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
    ),
) -> "tuple[AutoTokenizer, UMT5EncoderModel]":
    """
    Load the UMT5 tokenizer and text encoder for the given model path.

    Args:
        model_path: HF repo id or local directory holding the model.
        tokenizer_subfolder: Optional subfolder containing the tokenizer files.
        encoder_subfolder: Optional subfolder containing the encoder weights.
        torch_dtype: Dtype to load the encoder weights in (default bfloat16).
        comfy_files: Candidate single-file checkpoints, relative to
            COMFY_MODELS_PATH, preferred over the from_pretrained load when
            one exists locally.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
    comfy_path = get_comfy_path(comfy_files)
    # NOTE(review): the next line unconditionally discards the comfy path,
    # which makes the single-file branch below dead code. This looks like a
    # debug leftover; confirm whether the local single-file fallback should
    # be re-enabled (and whether UMT5EncoderModel supports from_single_file)
    # before removing it.
    comfy_path = None
    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
        )
    return tokenizer, text_encoder