From be71cc75ce805168dd16f7315f283880525d5aa7 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 14 Aug 2025 10:07:18 -0600
Subject: [PATCH] Switch to unified text encoder for wan models. Prep for 2.2
 14b

---
 toolkit/models/loaders/__init__.py |   0
 toolkit/models/loaders/comfy.py    |  15 ++++
 toolkit/models/loaders/umt5.py     |  32 +++++++++
 toolkit/models/wan21/wan21.py      | 112 +++++++++++++----------------
 toolkit/paths.py                   |   4 ++
 5 files changed, 101 insertions(+), 62 deletions(-)
 create mode 100644 toolkit/models/loaders/__init__.py
 create mode 100644 toolkit/models/loaders/comfy.py
 create mode 100644 toolkit/models/loaders/umt5.py

diff --git a/toolkit/models/loaders/__init__.py b/toolkit/models/loaders/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/toolkit/models/loaders/comfy.py b/toolkit/models/loaders/comfy.py
new file mode 100644
index 00000000..38a32046
--- /dev/null
+++ b/toolkit/models/loaders/comfy.py
@@ -0,0 +1,15 @@
+import os
+from typing import List, Optional
+from toolkit.paths import COMFY_MODELS_PATH
+
+
+def get_comfy_path(comfy_files: List[str]) -> Optional[str]:
+    """
+    Get the path to the first existing file in COMFY_MODELS_PATH, or None.
+    """
+    if COMFY_MODELS_PATH is not None and comfy_files is not None and len(comfy_files) > 0:
+        for file in comfy_files:
+            file_path = os.path.join(COMFY_MODELS_PATH, file)
+            if os.path.exists(file_path):
+                return file_path
+    return None
diff --git a/toolkit/models/loaders/umt5.py b/toolkit/models/loaders/umt5.py
new file mode 100644
index 00000000..fd666269
--- /dev/null
+++ b/toolkit/models/loaders/umt5.py
@@ -0,0 +1,32 @@
+from typing import List, Optional, Tuple
+import torch
+from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5EncoderModel
+from toolkit.models.loaders.comfy import get_comfy_path
+
+
+def get_umt5_encoder(
+    model_path: str,
+    tokenizer_subfolder: Optional[str] = None,
+    encoder_subfolder: Optional[str] = None,
+    torch_dtype: torch.dtype = torch.bfloat16,
+    comfy_files: List[str] = [
+        "text_encoders/umt5_xxl_fp16.safetensors",
+        "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
+    ],
+) -> Tuple[PreTrainedTokenizerBase, UMT5EncoderModel]:
+    """
+    Load the UMT5 tokenizer and encoder model from the specified path.
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
+    comfy_path = get_comfy_path(comfy_files)
+    comfy_path = None  # single-file ComfyUI loading is disabled for now
+    if comfy_path is not None:
+        text_encoder = UMT5EncoderModel.from_single_file(
+            comfy_path, torch_dtype=torch_dtype
+        )
+    else:
+        print(f"Using {model_path} for UMT5 encoder.")
+        text_encoder = UMT5EncoderModel.from_pretrained(
+            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
+        )
+    return tokenizer, text_encoder
diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index 7146630d..bdc2f601 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -42,6 +42,8 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from typing import Any, Callable, Dict, List, Optional, Union
 from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
+from toolkit.util.quantize import quantize_model
+from toolkit.models.loaders.umt5 import get_umt5_encoder
 
 # for generation only?
 scheduler_configUniPC = {
@@ -308,6 +310,8 @@ class Wan21(BaseModel):
     arch = 'wan21'
     _wan_generation_scheduler_config = scheduler_configUniPC
     _wan_expand_timesteps = False
+
+    _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
     def __init__(
             self,
             device,
@@ -334,27 +338,10 @@ class Wan21(BaseModel):
         def get_train_scheduler():
             scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
             return scheduler
-
-    def load_model(self):
-        dtype = self.torch_dtype
-        model_path = self.model_config.name_or_path
-
-        self.print_and_status_update("Loading Wan model")
-        subfolder = 'transformer'
-        transformer_path = model_path
-        if os.path.exists(transformer_path):
-            subfolder = None
-            transformer_path = os.path.join(transformer_path, 'transformer')
-
-        te_path = self.model_config.extras_name_or_path
-        if os.path.exists(os.path.join(model_path, 'text_encoder')):
-            te_path = model_path
-
-        vae_path = self.model_config.extras_name_or_path
-        if os.path.exists(os.path.join(model_path, 'vae')):
-            vae_path = model_path
-
+
+    def load_wan_transformer(self, transformer_path, subfolder=None):
         self.print_and_status_update("Loading transformer")
+        dtype = self.torch_dtype
         transformer = WanTransformer3DModel.from_pretrained(
             transformer_path,
             subfolder=subfolder,
@@ -379,55 +366,53 @@
                 "Loading LoRA is not supported for Wan2.1 models currently")
 
         flush()
-
+
         if self.model_config.quantize:
-            print("Quantizing Transformer")
-            quantization_args = self.model_config.quantize_kwargs
-            if 'exclude' not in quantization_args:
-                quantization_args['exclude'] = []
-            # patch the state dict method
-            patch_dequantization_on_save(transformer)
-            quantization_type = get_qtype(self.model_config.qtype)
-            if self.model_config.low_vram:
-                print("Quantizing blocks")
-                orig_exclude = copy.deepcopy(quantization_args['exclude'])
-                # quantize each block
-                idx = 0
-                for block in tqdm(transformer.blocks):
-                    block.to(self.device_torch)
-                    quantize(block, weights=quantization_type,
-                             **quantization_args)
-                    freeze(block)
-                    idx += 1
-                    flush()
+            self.print_and_status_update("Quantizing Transformer")
+            quantize_model(self, transformer)
+            flush()
+
+        if self.model_config.low_vram:
+            self.print_and_status_update("Moving transformer to CPU")
+            transformer.to('cpu')
 
-                print("Quantizing the rest")
-                low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
-                low_vram_exclude.append('blocks.*')
-                quantization_args['exclude'] = low_vram_exclude
-                # quantize the rest
-                transformer.to(self.device_torch)
-                quantize(transformer, weights=quantization_type,
-                         **quantization_args)
+        return transformer
 
-                quantization_args['exclude'] = orig_exclude
-            else:
-                # do it in one go
-                quantize(transformer, weights=quantization_type,
-                         **quantization_args)
-            freeze(transformer)
-            # move it to the cpu for now
-            transformer.to("cpu")
-        else:
-            transformer.to(self.device_torch, dtype=dtype)
+    def load_model(self):
+        dtype = self.torch_dtype
+        model_path = self.model_config.name_or_path
+
+        self.print_and_status_update("Loading Wan model")
+        subfolder = 'transformer'
+        transformer_path = model_path
+        if os.path.exists(transformer_path):
+            subfolder = None
+            transformer_path = os.path.join(transformer_path, 'transformer')
+
+        te_path = "ai-toolkit/umt5_xxl_encoder"
+        if os.path.exists(os.path.join(model_path, 'text_encoder')):
+            te_path = model_path
+
+        vae_path = self.model_config.extras_name_or_path
+        if os.path.exists(os.path.join(model_path, 'vae')):
+            vae_path = model_path
+
+        transformer = self.load_wan_transformer(
+            transformer_path,
+            subfolder=subfolder,
+        )
         flush()
 
         self.print_and_status_update("Loading UMT5EncoderModel")
-        tokenizer = AutoTokenizer.from_pretrained(
-            te_path, subfolder="tokenizer", torch_dtype=dtype)
-        text_encoder = UMT5EncoderModel.from_pretrained(
-            te_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype)
+
+        tokenizer, text_encoder = get_umt5_encoder(
+            model_path=te_path,
+            tokenizer_subfolder="tokenizer",
+            encoder_subfolder="text_encoder",
+            torch_dtype=dtype,
+            comfy_files=self._comfy_te_file
+        )
         text_encoder.to(self.device_torch, dtype=dtype)
 
         flush()
@@ -678,3 +663,6 @@ class Wan21(BaseModel):
 
     def get_base_model_version(self):
         return "wan_2.1"
+
+    def get_transformer_block_names(self):
+        return ['blocks']
diff --git a/toolkit/paths.py b/toolkit/paths.py
index 4b2376d6..edd36ce1 100644
--- a/toolkit/paths.py
+++ b/toolkit/paths.py
@@ -5,6 +5,10 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
 KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
 ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
 DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
+COMFY_PATH = os.getenv("COMFY_PATH", None)
+COMFY_MODELS_PATH = None
+if COMFY_PATH:
+    COMFY_MODELS_PATH = os.path.join(COMFY_PATH, "models")
 
 # check if ENV variable is set
 if 'MODELS_PATH' in os.environ:
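
A minimal usage sketch of the new helpers, assuming a ComfyUI install at
~/ComfyUI (an example location only; the model repo and subfolder names are
the ones load_model() passes above). COMFY_PATH must be exported before
toolkit.paths is first imported, since COMFY_MODELS_PATH is resolved at
module import time:

    import os
    import torch

    # assumed ComfyUI install location; set before importing toolkit.paths
    os.environ["COMFY_PATH"] = os.path.expanduser("~/ComfyUI")

    from toolkit.models.loaders.umt5 import get_umt5_encoder

    # returns a (tokenizer, encoder) pair; with the single-file branch
    # disabled in this patch, the encoder always loads via from_pretrained
    tokenizer, text_encoder = get_umt5_encoder(
        model_path="ai-toolkit/umt5_xxl_encoder",
        tokenizer_subfolder="tokenizer",
        encoder_subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
    )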