Switch to unified text encoder for wan models. Prep for 2.2 14b

This commit is contained in:
Jaret Burkett
2025-08-14 10:07:18 -06:00
parent e12bb21780
commit be71cc75ce
5 changed files with 101 additions and 62 deletions

View File

View File

@@ -0,0 +1,15 @@
import os
from typing import List
from toolkit.paths import COMFY_MODELS_PATH
def get_comfy_path(comfy_files: List[str]) -> str:
    """
    Return the first entry of *comfy_files* that exists under
    COMFY_MODELS_PATH, or None when ComfyUI is not configured,
    no candidates are given, or none of them exist on disk.
    """
    # Bail out early when there is nothing to search.
    if COMFY_MODELS_PATH is None or not comfy_files:
        return None
    for candidate in comfy_files:
        candidate_path = os.path.join(COMFY_MODELS_PATH, candidate)
        if os.path.exists(candidate_path):
            return candidate_path
    return None

View File

@@ -0,0 +1,32 @@
from typing import List
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from toolkit.models.loaders.comfy import get_comfy_path
def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: str = None,
    encoder_subfolder: str = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: List[str] = (
        "text_encoders/umt5_xxl_fp16.safetensors",
        "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
    ),
) -> tuple:
    """
    Load the UMT5 tokenizer and text encoder.

    Prefers a local ComfyUI single-file checkpoint when one of *comfy_files*
    exists under the ComfyUI models directory; otherwise falls back to
    loading from *model_path* (HF repo id or local diffusers-style folder).

    Args:
        model_path: HF repo id or local path for the tokenizer/encoder.
        tokenizer_subfolder: subfolder holding the tokenizer, if any.
        encoder_subfolder: subfolder holding the encoder weights, if any.
        torch_dtype: dtype to load the encoder weights in.
        comfy_files: candidate single-file checkpoints, relative to the
            ComfyUI models directory, tried in order.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
    comfy_path = get_comfy_path(comfy_files)
    # Bug fix: a leftover debug line (`comfy_path = None`) previously
    # discarded the lookup above, making the single-file branch unreachable.
    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
        )
    return tokenizer, text_encoder

View File

@@ -42,6 +42,8 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
from toolkit.util.quantize import quantize_model
from toolkit.models.loaders.umt5 import get_umt5_encoder
# for generation only? # for generation only?
scheduler_configUniPC = { scheduler_configUniPC = {
@@ -308,6 +310,8 @@ class Wan21(BaseModel):
arch = 'wan21' arch = 'wan21'
_wan_generation_scheduler_config = scheduler_configUniPC _wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = False _wan_expand_timesteps = False
_comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
def __init__( def __init__(
self, self,
device, device,
@@ -334,27 +338,10 @@ class Wan21(BaseModel):
def get_train_scheduler(): def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler return scheduler
def load_model(self): def load_wan_transformer(self, transformer_path, subfolder=None):
dtype = self.torch_dtype
model_path = self.model_config.name_or_path
self.print_and_status_update("Loading Wan model")
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
subfolder = None
transformer_path = os.path.join(transformer_path, 'transformer')
te_path = self.model_config.extras_name_or_path
if os.path.exists(os.path.join(model_path, 'text_encoder')):
te_path = model_path
vae_path = self.model_config.extras_name_or_path
if os.path.exists(os.path.join(model_path, 'vae')):
vae_path = model_path
self.print_and_status_update("Loading transformer") self.print_and_status_update("Loading transformer")
dtype = self.torch_dtype
transformer = WanTransformer3DModel.from_pretrained( transformer = WanTransformer3DModel.from_pretrained(
transformer_path, transformer_path,
subfolder=subfolder, subfolder=subfolder,
@@ -379,55 +366,53 @@ class Wan21(BaseModel):
"Loading LoRA is not supported for Wan2.1 models currently") "Loading LoRA is not supported for Wan2.1 models currently")
flush() flush()
if self.model_config.quantize: if self.model_config.quantize:
print("Quantizing Transformer") self.print_and_status_update("Quantizing Transformer")
quantization_args = self.model_config.quantize_kwargs quantize_model(self, transformer)
if 'exclude' not in quantization_args: flush()
quantization_args['exclude'] = []
# patch the state dict method if self.model_config.low_vram:
patch_dequantization_on_save(transformer) self.print_and_status_update("Moving transformer to CPU")
quantization_type = get_qtype(self.model_config.qtype) transformer.to('cpu')
if self.model_config.low_vram:
print("Quantizing blocks")
orig_exclude = copy.deepcopy(quantization_args['exclude'])
# quantize each block
idx = 0
for block in tqdm(transformer.blocks):
block.to(self.device_torch)
quantize(block, weights=quantization_type,
**quantization_args)
freeze(block)
idx += 1
flush()
print("Quantizing the rest") return transformer
low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
low_vram_exclude.append('blocks.*')
quantization_args['exclude'] = low_vram_exclude
# quantize the rest
transformer.to(self.device_torch)
quantize(transformer, weights=quantization_type,
**quantization_args)
quantization_args['exclude'] = orig_exclude def load_model(self):
else: dtype = self.torch_dtype
# do it in one go model_path = self.model_config.name_or_path
quantize(transformer, weights=quantization_type,
**quantization_args) self.print_and_status_update("Loading Wan model")
freeze(transformer) subfolder = 'transformer'
# move it to the cpu for now transformer_path = model_path
transformer.to("cpu") if os.path.exists(transformer_path):
else: subfolder = None
transformer.to(self.device_torch, dtype=dtype) transformer_path = os.path.join(transformer_path, 'transformer')
te_path = "ai-toolkit/umt5_xxl_encoder"
if os.path.exists(os.path.join(model_path, 'text_encoder')):
te_path = model_path
vae_path = self.model_config.extras_name_or_path
if os.path.exists(os.path.join(model_path, 'vae')):
vae_path = model_path
transformer = self.load_wan_transformer(
transformer_path,
subfolder=subfolder,
)
flush() flush()
self.print_and_status_update("Loading UMT5EncoderModel") self.print_and_status_update("Loading UMT5EncoderModel")
tokenizer = AutoTokenizer.from_pretrained(
te_path, subfolder="tokenizer", torch_dtype=dtype) tokenizer, text_encoder = get_umt5_encoder(
text_encoder = UMT5EncoderModel.from_pretrained( model_path=te_path,
te_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype) tokenizer_subfolder="tokenizer",
encoder_subfolder="text_encoder",
torch_dtype=dtype,
comfy_files=self._comfy_te_file
)
text_encoder.to(self.device_torch, dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype)
flush() flush()
@@ -678,3 +663,6 @@ class Wan21(BaseModel):
def get_base_model_version(self): def get_base_model_version(self):
return "wan_2.1" return "wan_2.1"
def get_transformer_block_names(self):
return ['blocks']

View File

@@ -5,6 +5,10 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps") KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs") ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs") DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
COMFY_PATH = os.getenv("COMFY_PATH", None)
COMFY_MODELS_PATH = None
if COMFY_PATH:
COMFY_MODELS_PATH = os.path.join(COMFY_PATH, "models")
# check if ENV variable is set # check if ENV variable is set
if 'MODELS_PATH' in os.environ: if 'MODELS_PATH' in os.environ: