mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Switch to unified text encoder for Wan models. Prep for 2.2 14b
This commit is contained in:
0
toolkit/models/loaders/__init__.py
Normal file
0
toolkit/models/loaders/__init__.py
Normal file
15
toolkit/models/loaders/comfy.py
Normal file
15
toolkit/models/loaders/comfy.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import os
from typing import List, Optional

from toolkit.paths import COMFY_MODELS_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def get_comfy_path(comfy_files: List[str]) -> Optional[str]:
    """
    Return the path of the first candidate file that exists under COMFY_MODELS_PATH.

    Args:
        comfy_files: Candidate file paths, relative to COMFY_MODELS_PATH,
            checked in the order given.

    Returns:
        The joined path of the first candidate that exists on disk, or None
        when COMFY_MODELS_PATH is not set, ``comfy_files`` is None or empty,
        or none of the candidates exist.
    """
    # Nothing to search: no models directory configured, or no candidates given.
    if COMFY_MODELS_PATH is None or not comfy_files:
        return None

    for file in comfy_files:
        file_path = os.path.join(COMFY_MODELS_PATH, file)
        if os.path.exists(file_path):
            # Earlier entries win, so callers should list preferred files first.
            return file_path

    return None
|
||||||
32
toolkit/models/loaders/umt5.py
Normal file
32
toolkit/models/loaders/umt5.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from typing import List
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||||
|
from toolkit.models.loaders.comfy import get_comfy_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: str = None,
    encoder_subfolder: str = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: List[str] = None,
) -> tuple:
    """
    Load the UMT5 tokenizer and encoder model from the specified path.

    Args:
        model_path: Path or identifier handed to ``from_pretrained`` for both
            the tokenizer and the encoder.
        tokenizer_subfolder: Subfolder of ``model_path`` to load the tokenizer from.
        encoder_subfolder: Subfolder of ``model_path`` to load the encoder from.
        torch_dtype: dtype passed to the encoder loader.
        comfy_files: Candidate single-file checkpoints, relative to the comfy
            models directory, passed to ``get_comfy_path``. When None, the
            fp16 and fp8-scaled UMT5-XXL file names are used.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation.
    if comfy_files is None:
        comfy_files = [
            "text_encoders/umt5_xxl_fp16.safetensors",
            "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
        ]

    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)

    comfy_path = get_comfy_path(comfy_files)
    # NOTE(review): the lookup result is discarded here, so the single-file
    # branch below never runs and the encoder is always loaded from
    # `model_path`. Presumably single-file loading is not ready yet -- confirm
    # before removing this override.
    comfy_path = None

    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
        )

    return tokenizer, text_encoder
|
||||||
@@ -42,6 +42,8 @@ from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
|
|||||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
|
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
|
||||||
|
from toolkit.util.quantize import quantize_model
|
||||||
|
from toolkit.models.loaders.umt5 import get_umt5_encoder
|
||||||
|
|
||||||
# for generation only?
|
# for generation only?
|
||||||
scheduler_configUniPC = {
|
scheduler_configUniPC = {
|
||||||
@@ -308,6 +310,8 @@ class Wan21(BaseModel):
|
|||||||
arch = 'wan21'
|
arch = 'wan21'
|
||||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||||
_wan_expand_timesteps = False
|
_wan_expand_timesteps = False
|
||||||
|
|
||||||
|
_comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device,
|
device,
|
||||||
@@ -334,27 +338,10 @@ class Wan21(BaseModel):
|
|||||||
def get_train_scheduler():
|
def get_train_scheduler():
|
||||||
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
def load_model(self):
|
def load_wan_transformer(self, transformer_path, subfolder=None):
|
||||||
dtype = self.torch_dtype
|
|
||||||
model_path = self.model_config.name_or_path
|
|
||||||
|
|
||||||
self.print_and_status_update("Loading Wan model")
|
|
||||||
subfolder = 'transformer'
|
|
||||||
transformer_path = model_path
|
|
||||||
if os.path.exists(transformer_path):
|
|
||||||
subfolder = None
|
|
||||||
transformer_path = os.path.join(transformer_path, 'transformer')
|
|
||||||
|
|
||||||
te_path = self.model_config.extras_name_or_path
|
|
||||||
if os.path.exists(os.path.join(model_path, 'text_encoder')):
|
|
||||||
te_path = model_path
|
|
||||||
|
|
||||||
vae_path = self.model_config.extras_name_or_path
|
|
||||||
if os.path.exists(os.path.join(model_path, 'vae')):
|
|
||||||
vae_path = model_path
|
|
||||||
|
|
||||||
self.print_and_status_update("Loading transformer")
|
self.print_and_status_update("Loading transformer")
|
||||||
|
dtype = self.torch_dtype
|
||||||
transformer = WanTransformer3DModel.from_pretrained(
|
transformer = WanTransformer3DModel.from_pretrained(
|
||||||
transformer_path,
|
transformer_path,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
@@ -379,55 +366,53 @@ class Wan21(BaseModel):
|
|||||||
"Loading LoRA is not supported for Wan2.1 models currently")
|
"Loading LoRA is not supported for Wan2.1 models currently")
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.model_config.quantize:
|
if self.model_config.quantize:
|
||||||
print("Quantizing Transformer")
|
self.print_and_status_update("Quantizing Transformer")
|
||||||
quantization_args = self.model_config.quantize_kwargs
|
quantize_model(self, transformer)
|
||||||
if 'exclude' not in quantization_args:
|
flush()
|
||||||
quantization_args['exclude'] = []
|
|
||||||
# patch the state dict method
|
if self.model_config.low_vram:
|
||||||
patch_dequantization_on_save(transformer)
|
self.print_and_status_update("Moving transformer to CPU")
|
||||||
quantization_type = get_qtype(self.model_config.qtype)
|
transformer.to('cpu')
|
||||||
if self.model_config.low_vram:
|
|
||||||
print("Quantizing blocks")
|
|
||||||
orig_exclude = copy.deepcopy(quantization_args['exclude'])
|
|
||||||
# quantize each block
|
|
||||||
idx = 0
|
|
||||||
for block in tqdm(transformer.blocks):
|
|
||||||
block.to(self.device_torch)
|
|
||||||
quantize(block, weights=quantization_type,
|
|
||||||
**quantization_args)
|
|
||||||
freeze(block)
|
|
||||||
idx += 1
|
|
||||||
flush()
|
|
||||||
|
|
||||||
print("Quantizing the rest")
|
return transformer
|
||||||
low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
|
|
||||||
low_vram_exclude.append('blocks.*')
|
|
||||||
quantization_args['exclude'] = low_vram_exclude
|
|
||||||
# quantize the rest
|
|
||||||
transformer.to(self.device_torch)
|
|
||||||
quantize(transformer, weights=quantization_type,
|
|
||||||
**quantization_args)
|
|
||||||
|
|
||||||
quantization_args['exclude'] = orig_exclude
|
def load_model(self):
|
||||||
else:
|
dtype = self.torch_dtype
|
||||||
# do it in one go
|
model_path = self.model_config.name_or_path
|
||||||
quantize(transformer, weights=quantization_type,
|
|
||||||
**quantization_args)
|
self.print_and_status_update("Loading Wan model")
|
||||||
freeze(transformer)
|
subfolder = 'transformer'
|
||||||
# move it to the cpu for now
|
transformer_path = model_path
|
||||||
transformer.to("cpu")
|
if os.path.exists(transformer_path):
|
||||||
else:
|
subfolder = None
|
||||||
transformer.to(self.device_torch, dtype=dtype)
|
transformer_path = os.path.join(transformer_path, 'transformer')
|
||||||
|
|
||||||
|
te_path = "ai-toolkit/umt5_xxl_encoder"
|
||||||
|
if os.path.exists(os.path.join(model_path, 'text_encoder')):
|
||||||
|
te_path = model_path
|
||||||
|
|
||||||
|
vae_path = self.model_config.extras_name_or_path
|
||||||
|
if os.path.exists(os.path.join(model_path, 'vae')):
|
||||||
|
vae_path = model_path
|
||||||
|
|
||||||
|
transformer = self.load_wan_transformer(
|
||||||
|
transformer_path,
|
||||||
|
subfolder=subfolder,
|
||||||
|
)
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
self.print_and_status_update("Loading UMT5EncoderModel")
|
self.print_and_status_update("Loading UMT5EncoderModel")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
te_path, subfolder="tokenizer", torch_dtype=dtype)
|
tokenizer, text_encoder = get_umt5_encoder(
|
||||||
text_encoder = UMT5EncoderModel.from_pretrained(
|
model_path=te_path,
|
||||||
te_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype)
|
tokenizer_subfolder="tokenizer",
|
||||||
|
encoder_subfolder="text_encoder",
|
||||||
|
torch_dtype=dtype,
|
||||||
|
comfy_files=self._comfy_te_file
|
||||||
|
)
|
||||||
|
|
||||||
text_encoder.to(self.device_torch, dtype=dtype)
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
@@ -678,3 +663,6 @@ class Wan21(BaseModel):
|
|||||||
|
|
||||||
def get_base_model_version(self):
|
def get_base_model_version(self):
|
||||||
return "wan_2.1"
|
return "wan_2.1"
|
||||||
|
|
||||||
|
def get_transformer_block_names(self):
|
||||||
|
return ['blocks']
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
|||||||
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
||||||
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
||||||
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
||||||
|
COMFY_PATH = os.getenv("COMFY_PATH", None)
|
||||||
|
COMFY_MODELS_PATH = None
|
||||||
|
if COMFY_PATH:
|
||||||
|
COMFY_MODELS_PATH = os.path.join(COMFY_PATH, "models")
|
||||||
|
|
||||||
# check if ENV variable is set
|
# check if ENV variable is set
|
||||||
if 'MODELS_PATH' in os.environ:
|
if 'MODELS_PATH' in os.environ:
|
||||||
|
|||||||
Reference in New Issue
Block a user