mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Switch to unified text encoder for wan models. Pred for 2.2 14b
This commit is contained in:
32
toolkit/models/loaders/umt5.py
Normal file
32
toolkit/models/loaders/umt5.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from toolkit.models.loaders.comfy import get_comfy_path
|
||||
|
||||
|
||||
def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: Optional[str] = None,
    encoder_subfolder: Optional[str] = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: Optional[List[str]] = None,
) -> Tuple[AutoTokenizer, UMT5EncoderModel]:
    """
    Load the UMT5 tokenizer and encoder model from the specified path.

    Args:
        model_path: Hugging Face model id or local path used for both the
            tokenizer and (when no ComfyUI checkpoint is used) the encoder.
        tokenizer_subfolder: Optional subfolder containing the tokenizer files.
        encoder_subfolder: Optional subfolder containing the encoder weights.
        torch_dtype: Dtype to load the encoder weights in (default bfloat16).
        comfy_files: Candidate single-file checkpoints to look for in a local
            ComfyUI installation. Defaults to the standard umt5_xxl files.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple.
    """
    # Avoid a mutable default argument; None means "use the standard files".
    if comfy_files is None:
        comfy_files = [
            "text_encoders/umt5_xxl_fp16.safetensors",
            "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
        ]

    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)

    # NOTE(review): the single-file (ComfyUI) load path is deliberately left
    # disabled — the original code overwrote comfy_path with None immediately
    # after computing it, making the from_single_file branch dead code.
    # transformers' UMT5EncoderModel does not provide from_single_file (that is
    # a diffusers mixin), so enabling this as-is would likely raise. Remove the
    # override once single-file loading is confirmed to work — TODO confirm.
    comfy_path = get_comfy_path(comfy_files)
    comfy_path = None
    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
        )

    return tokenizer, text_encoder
|
||||
Reference in New Issue
Block a user