Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00

Switch to unified text encoder for wan models. Prep for 2.2 14b
toolkit/models/loaders/__init__.py (new file, empty)
toolkit/models/loaders/comfy.py (new file)
@@ -0,0 +1,15 @@
import os
from typing import List, Optional
from toolkit.paths import COMFY_MODELS_PATH


def get_comfy_path(comfy_files: List[str]) -> Optional[str]:
    """
    Return the path to the first existing file under COMFY_MODELS_PATH, or None.
    """
    if COMFY_MODELS_PATH is not None and comfy_files is not None and len(comfy_files) > 0:
        for file in comfy_files:
            file_path = os.path.join(COMFY_MODELS_PATH, file)
            if os.path.exists(file_path):
                return file_path
    return None
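
A usage sketch, not part of the commit (the candidate files below are the defaults that umt5.py passes in):

# Resolve a ComfyUI-managed checkpoint; get_comfy_path returns None when
# COMFY_PATH is unset or none of the candidates exist under <COMFY_PATH>/models.
from toolkit.models.loaders.comfy import get_comfy_path

path = get_comfy_path([
    "text_encoders/umt5_xxl_fp16.safetensors",
    "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
])
if path is None:
    print("No local ComfyUI checkpoint found; falling back to Hugging Face.")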
toolkit/models/loaders/umt5.py (new file)
@@ -0,0 +1,31 @@
from typing import List, Tuple
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from toolkit.models.loaders.comfy import get_comfy_path


def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: str = None,
    encoder_subfolder: str = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: List[str] = [
        "text_encoders/umt5_xxl_fp16.safetensors",
        "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
    ],
) -> Tuple[AutoTokenizer, UMT5EncoderModel]:
    """
    Load the UMT5 tokenizer and encoder, preferring a local ComfyUI checkpoint.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
    comfy_path = get_comfy_path(comfy_files)
    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
        )
    return tokenizer, text_encoder
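
A usage sketch, not part of the commit; it mirrors the call the Wan21 loader makes further down in this diff, including the default repo id it falls back to:

import torch
from toolkit.models.loaders.umt5 import get_umt5_encoder

# Prefers a ComfyUI single-file checkpoint when one resolves;
# otherwise loads from the Hugging Face repo below.
tokenizer, text_encoder = get_umt5_encoder(
    model_path="ai-toolkit/umt5_xxl_encoder",
    tokenizer_subfolder="tokenizer",
    encoder_subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
)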
@@ -42,6 +42,8 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from typing import Any, Callable, Dict, List, Optional, Union
 from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
+from toolkit.util.quantize import quantize_model
+from toolkit.models.loaders.umt5 import get_umt5_encoder
 
 # for generation only?
 scheduler_configUniPC = {
@@ -308,6 +310,8 @@ class Wan21(BaseModel):
     arch = 'wan21'
     _wan_generation_scheduler_config = scheduler_configUniPC
     _wan_expand_timesteps = False
+
+    _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
     def __init__(
         self,
         device,
@@ -334,27 +338,10 @@ class Wan21(BaseModel):
         def get_train_scheduler():
             scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
             return scheduler
 
-    def load_model(self):
-        dtype = self.torch_dtype
-        model_path = self.model_config.name_or_path
-
-        self.print_and_status_update("Loading Wan model")
-        subfolder = 'transformer'
-        transformer_path = model_path
-        if os.path.exists(transformer_path):
-            subfolder = None
-            transformer_path = os.path.join(transformer_path, 'transformer')
-
-        te_path = self.model_config.extras_name_or_path
-        if os.path.exists(os.path.join(model_path, 'text_encoder')):
-            te_path = model_path
-
-        vae_path = self.model_config.extras_name_or_path
-        if os.path.exists(os.path.join(model_path, 'vae')):
-            vae_path = model_path
-
-
+    def load_wan_transformer(self, transformer_path, subfolder=None):
+        self.print_and_status_update("Loading transformer")
+        dtype = self.torch_dtype
         transformer = WanTransformer3DModel.from_pretrained(
             transformer_path,
             subfolder=subfolder,
@@ -379,55 +366,53 @@ class Wan21(BaseModel):
                 "Loading LoRA is not supported for Wan2.1 models currently")
 
         flush()
 
 
         if self.model_config.quantize:
-            print("Quantizing Transformer")
-            quantization_args = self.model_config.quantize_kwargs
-            if 'exclude' not in quantization_args:
-                quantization_args['exclude'] = []
-            # patch the state dict method
-            patch_dequantization_on_save(transformer)
-            quantization_type = get_qtype(self.model_config.qtype)
-            if self.model_config.low_vram:
-                print("Quantizing blocks")
-                orig_exclude = copy.deepcopy(quantization_args['exclude'])
-                # quantize each block
-                idx = 0
-                for block in tqdm(transformer.blocks):
-                    block.to(self.device_torch)
-                    quantize(block, weights=quantization_type,
-                             **quantization_args)
-                    freeze(block)
-                    idx += 1
-                    flush()
+            self.print_and_status_update("Quantizing Transformer")
+            quantize_model(self, transformer)
+            flush()
+
+            if self.model_config.low_vram:
+                self.print_and_status_update("Moving transformer to CPU")
+                transformer.to('cpu')
-
-                print("Quantizing the rest")
-                low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
-                low_vram_exclude.append('blocks.*')
-                quantization_args['exclude'] = low_vram_exclude
-                # quantize the rest
-                transformer.to(self.device_torch)
-                quantize(transformer, weights=quantization_type,
-                         **quantization_args)
-                return transformer
-
-                quantization_args['exclude'] = orig_exclude
-            else:
-                # do it in one go
-                quantize(transformer, weights=quantization_type,
-                         **quantization_args)
-            freeze(transformer)
-            # move it to the cpu for now
-            transformer.to("cpu")
         else:
             transformer.to(self.device_torch, dtype=dtype)
 
+    def load_model(self):
+        dtype = self.torch_dtype
+        model_path = self.model_config.name_or_path
+
+        self.print_and_status_update("Loading Wan model")
+        subfolder = 'transformer'
+        transformer_path = model_path
+        if os.path.exists(transformer_path):
+            subfolder = None
+            transformer_path = os.path.join(transformer_path, 'transformer')
+
+        te_path = "ai-toolkit/umt5_xxl_encoder"
+        if os.path.exists(os.path.join(model_path, 'text_encoder')):
+            te_path = model_path
+
+        vae_path = self.model_config.extras_name_or_path
+        if os.path.exists(os.path.join(model_path, 'vae')):
+            vae_path = model_path
+
+        transformer = self.load_wan_transformer(
+            transformer_path,
+            subfolder=subfolder,
+        )
+
+        flush()
+
         self.print_and_status_update("Loading UMT5EncoderModel")
-        tokenizer = AutoTokenizer.from_pretrained(
-            te_path, subfolder="tokenizer", torch_dtype=dtype)
-        text_encoder = UMT5EncoderModel.from_pretrained(
-            te_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype)
+        tokenizer, text_encoder = get_umt5_encoder(
+            model_path=te_path,
+            tokenizer_subfolder="tokenizer",
+            encoder_subfolder="text_encoder",
+            torch_dtype=dtype,
+            comfy_files=self._comfy_te_file
+        )
 
         text_encoder.to(self.device_torch, dtype=dtype)
         flush()
@@ -678,3 +663,6 @@ class Wan21(BaseModel):
 
     def get_base_model_version(self):
         return "wan_2.1"
+
+    def get_transformer_block_names(self):
+        return ['blocks']
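
Since _comfy_te_file is a class attribute, "Prep for 2.2 14b" in the commit message presumably means a future arch can point the unified loader at its own ComfyUI checkpoint names. A hypothetical sketch; the subclass, arch string, and file list are assumptions, not part of this commit:

# Hypothetical: a Wan 2.2 arch would only need to override the checkpoint
# candidates that get_umt5_encoder() probes; load_model is inherited as-is.
# Assumes Wan21 is importable from the (unnamed) module this diff modifies.
class Wan22_14b(Wan21):
    arch = 'wan22_14b'  # assumed name
    _comfy_te_file = [
        'text_encoders/umt5_xxl_fp16.safetensors',  # reuse 2.1 defaults until 2.2 files exist
        'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors',
    ]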
toolkit/paths.py

@@ -5,6 +5,10 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
 KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
 ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
 DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
+COMFY_PATH = os.getenv("COMFY_PATH", None)
+COMFY_MODELS_PATH = None
+if COMFY_PATH:
+    COMFY_MODELS_PATH = os.path.join(COMFY_PATH, "models")
 
 # check if ENV variable is set
 if 'MODELS_PATH' in os.environ:
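
With the new COMFY_PATH hook, pointing the toolkit at an existing ComfyUI install takes one environment variable. A sketch; the install path is an example:

import os

# Must be set before toolkit.paths is imported, since the module reads
# the variable at import time.
os.environ["COMFY_PATH"] = "/home/user/ComfyUI"

from toolkit.paths import COMFY_MODELS_PATH
print(COMFY_MODELS_PATH)  # -> /home/user/ComfyUI/models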