Added support for caching text embeddings. This is initial support and will probably fail for some models. It still needs to be optimized.

This commit is contained in:
Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions
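Illustrative sketch (not part of the diff): the new `cache_text_embeddings` flag is a plain kwarg on `TrainConfig` and gets mirrored onto every dataset config, which is how the trainer changes below wire it up. This assumes `folder_path` is a valid `DatasetConfig` kwarg and that `DatasetConfig` accepts the `cache_text_embeddings` key this commit reads from it:

    # Hedged example of the config flow this commit introduces.
    from toolkit.config_modules import TrainConfig, DatasetConfig

    train = TrainConfig(cache_text_embeddings=True)
    raw_datasets = [{'folder_path': '/data/images'}]
    if train.cache_text_embeddings:
        for raw in raw_datasets:
            raw['cache_text_embeddings'] = True  # mirrors BaseSDTrainProcess below

    datasets = [DatasetConfig(**raw) for raw in raw_datasets]
    assert all(d.cache_text_embeddings for d in datasets)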

View File

@@ -168,7 +168,9 @@ class QwenImageModel(BaseModel):
         text_encoder = [pipe.text_encoder]
         tokenizer = [pipe.tokenizer]
-        pipe.transformer = pipe.transformer.to(self.device_torch)
+        # leave it on cpu for now
+        if not self.low_vram:
+            pipe.transformer = pipe.transformer.to(self.device_torch)
         flush()
         # just to make sure everything is on the right device and dtype
@@ -210,6 +212,7 @@ class QwenImageModel(BaseModel):
         generator: torch.Generator,
         extra: dict,
     ):
+        self.model.to(self.device_torch, dtype=self.torch_dtype)
         control_img = None
         if gen_config.ctrl_img is not None:
             raise NotImplementedError(

View File

@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, ConcatDataset
 from toolkit import train_tools
 from toolkit.basic import value_map, adain, get_mean_std
 from toolkit.clip_vision_adapter import ClipVisionAdapter
-from toolkit.config_modules import GuidanceConfig
+from toolkit.config_modules import GenerateImageConfig
 from toolkit.data_loader import get_dataloader_datasets
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
 from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
 from toolkit.util.wavelet_loss import wavelet_loss
 import torch.nn.functional as F
+from toolkit.unloader import unload_text_encoder
 def flush():
@@ -108,6 +109,33 @@ class SDTrainer(BaseSDTrainProcess):
     def before_model_load(self):
         pass
+    def cache_sample_prompts(self):
+        if self.train_config.disable_sampling:
+            return
+        if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0:
+            # cache all the samples
+            self.sd.sample_prompts_cache = []
+            sample_folder = os.path.join(self.save_root, 'samples')
+            output_path = os.path.join(sample_folder, 'test.jpg')
+            for i in range(len(self.sample_config.prompts)):
+                sample_item = self.sample_config.samples[i]
+                prompt = self.sample_config.prompts[i]
+                # needed so we can autoparse the prompt to handle flags
+                gen_img_config = GenerateImageConfig(
+                    prompt=prompt,  # it will autoparse the prompt
+                    negative_prompt=sample_item.neg,
+                    output_path=output_path,
+                )
+                positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
+                negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
+                self.sd.sample_prompts_cache.append({
+                    'conditional': positive,
+                    'unconditional': negative
+                })
     def before_dataset_load(self):
         self.assistant_adapter = None
@@ -143,6 +171,9 @@ class SDTrainer(BaseSDTrainProcess):
     def hook_before_train_loop(self):
         super().hook_before_train_loop()
+        if self.is_caching_text_embeddings:
+            # make sure model is on cpu for this part so we don't OOM
+            self.sd.unet.to('cpu')
         # cache unconditional embeds (blank prompt)
         with torch.no_grad():
@@ -195,15 +226,18 @@ class SDTrainer(BaseSDTrainProcess):
             self.negative_prompt_pool = [self.train_config.negative_prompt]
         # handle unload text encoder
-        if self.train_config.unload_text_encoder:
+        if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
             with torch.no_grad():
                 if self.train_config.train_text_encoder:
                     raise ValueError("Cannot unload text encoder if training text encoder")
                 # cache embeddings
                 print_acc("\n***** UNLOADING TEXT ENCODER *****")
-                print_acc("This will train only with a blank prompt or trigger word, if set")
-                print_acc("If this is not what you want, remove the unload_text_encoder flag")
+                if self.is_caching_text_embeddings:
+                    print_acc("Embeddings cached to disk. We don't need the text encoder anymore")
+                else:
+                    print_acc("This will train only with a blank prompt or trigger word, if set")
+                    print_acc("If this is not what you want, remove the unload_text_encoder flag")
                 print_acc("***********************************")
                 print_acc("")
                 self.sd.text_encoder_to(self.device_torch)
@@ -212,9 +246,16 @@ class SDTrainer(BaseSDTrainProcess):
                     self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
                 if self.train_config.diff_output_preservation:
                     self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
+                self.cache_sample_prompts()
-                # move back to cpu
-                self.sd.text_encoder_to('cpu')
+                # unload the text encoder
+                if self.is_caching_text_embeddings:
+                    unload_text_encoder(self.sd)
+                else:
+                    # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing.
+                    # keep legacy usage for now.
+                    self.sd.text_encoder_to("cpu")
         flush()
if self.train_config.diffusion_feature_extractor_path is not None:
@@ -923,11 +964,14 @@ class SDTrainer(BaseSDTrainProcess):
                         prompt = prompt.replace(trigger, class_name)
                         prompt_list[idx] = prompt
-                embeds_to_use = self.sd.encode_prompt(
-                    prompt_list,
-                    long_prompts=self.do_long_prompts).to(
-                        self.device_torch,
-                        dtype=dtype).detach()
+                if batch.prompt_embeds is not None:
+                    embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
+                else:
+                    embeds_to_use = self.sd.encode_prompt(
+                        prompt_list,
+                        long_prompts=self.do_long_prompts).to(
+                            self.device_torch,
+                            dtype=dtype).detach()
                 # don't use network on this
                 # self.network.multiplier = 0.0
@@ -1294,18 +1338,24 @@ class SDTrainer(BaseSDTrainProcess):
             with self.timer('encode_prompt'):
                 unconditional_embeds = None
-                if self.train_config.unload_text_encoder:
+                if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
                     with torch.set_grad_enabled(False):
-                        embeds_to_use = self.cached_blank_embeds.clone().detach().to(
-                            self.device_torch, dtype=dtype
-                        )
-                        if self.cached_trigger_embeds is not None and not is_reg:
-                            embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
+                        if batch.prompt_embeds is not None:
+                            # use the cached embeds
+                            conditional_embeds = batch.prompt_embeds.clone().detach().to(
                                 self.device_torch, dtype=dtype
                             )
-                        conditional_embeds = concat_prompt_embeds(
-                            [embeds_to_use] * noisy_latents.shape[0]
-                        )
+                        else:
+                            embeds_to_use = self.cached_blank_embeds.clone().detach().to(
+                                self.device_torch, dtype=dtype
+                            )
+                            if self.cached_trigger_embeds is not None and not is_reg:
+                                embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
+                                    self.device_torch, dtype=dtype
+                                )
+                            conditional_embeds = concat_prompt_embeds(
+                                [embeds_to_use] * noisy_latents.shape[0]
+                            )
                     if self.train_config.do_cfg:
                         unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
                             self.device_torch, dtype=dtype

View File

@@ -145,7 +145,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         raw_datasets = preprocess_dataset_raw_config(raw_datasets)
         self.datasets = None
         self.datasets_reg = None
+        self.dataset_configs: List[DatasetConfig] = []
         self.params = []
+        # add dataset text embedding cache to their config
+        if self.train_config.cache_text_embeddings:
+            for raw_dataset in raw_datasets:
+                raw_dataset['cache_text_embeddings'] = True
         if raw_datasets is not None and len(raw_datasets) > 0:
             for raw_dataset in raw_datasets:
                 dataset = DatasetConfig(**raw_dataset)
@@ -160,6 +167,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.datasets is None:
                     self.datasets = []
                 self.datasets.append(dataset)
+                self.dataset_configs.append(dataset)
+        self.is_caching_text_embeddings = any(
+            dataset.cache_text_embeddings for dataset in self.dataset_configs
+        )
+        # cannot train trigger word if caching text embeddings
+        if self.is_caching_text_embeddings and self.trigger_word is not None:
+            raise ValueError("Cannot train trigger word if caching text embeddings. Please remove the trigger word or disable text embedding caching.")
         self.embed_config = None
         embedding_raw = self.get_conf('embedding', None)
@@ -206,7 +222,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_embedding=self.embed_config is not None,
             train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder,
+            unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
             require_grads=False  # we ensure them later
         )
@@ -220,7 +236,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_embedding=self.embed_config is not None,
             train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder,
+            unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
             require_grads=True  # we check for grads when getting params
         )
@@ -235,7 +251,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.snr_gos: Union[LearnableSNRGamma, None] = None
         self.ema: ExponentialMovingAverage = None
-        validate_configs(self.train_config, self.model_config, self.save_config)
+        validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs)
         do_profiler = self.get_conf('torch_profiler', False)
         self.torch_profiler = None if not do_profiler else torch.profiler.profile(

View File

@@ -482,6 +482,8 @@ class TrainConfig:
         # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
         # will make training faster and use less vram
         self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
+        # will toggle all datasets to cache text embeddings
+        self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
         # for swapping which parameters are trained during training
         self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
         # 0.1 means 10% of the parameters are active at a time; lower is less vram, higher is more
@@ -1189,6 +1191,7 @@ def validate_configs(
     train_config: TrainConfig,
     model_config: ModelConfig,
     save_config: SaveConfig,
+    dataset_configs: List[DatasetConfig]
 ):
     if model_config.is_flux:
         if save_config.save_format != 'diffusers':
@@ -1200,3 +1203,18 @@ def validate_configs(
     if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
         raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
                          "Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
+    # see if any datasets are caching text embeddings
+    is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
+    if is_caching_text_embeddings:
+        # check if they are doing differential output preservation
+        if train_config.diff_output_preservation:
+            raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")
+        # make sure they are all cached
+        for dataset in dataset_configs:
+            if not dataset.cache_text_embeddings:
+                raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")

View File

@@ -558,6 +558,8 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
             self.cache_latents_all_latents()
         if self.is_caching_clip_vision_to_disk:
             self.cache_clip_vision_to_disk()
+        if self.is_caching_text_embeddings:
+            self.cache_text_embeddings()
         if self.is_generating_controls:
             # always do this last
             self.setup_controls()

View File

@@ -13,8 +13,8 @@ from toolkit import image_utils
 from toolkit.basic import get_quick_signature_string
 from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
     ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
-    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
+    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
@@ -32,6 +32,7 @@ def print_once(msg):
 class FileItemDTO(
     LatentCachingFileItemDTOMixin,
+    TextEmbeddingFileItemDTOMixin,
     CaptionProcessingDTOMixin,
     ImageProcessingDTOMixin,
     ControlFileItemDTOMixin,
@@ -124,6 +125,7 @@ class FileItemDTO(
     def cleanup(self):
         self.tensor = None
         self.cleanup_latent()
+        self.cleanup_text_embedding()
         self.cleanup_control()
         self.cleanup_inpaint()
         self.cleanup_clip_image()
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
         try:
             self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
             is_latents_cached = self.file_items[0].is_latent_cached
+            is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
             self.tensor: Union[torch.Tensor, None] = None
             self.latents: Union[torch.Tensor, None] = None
             self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
             if is_latents_cached:
                 self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
             self.control_tensor: Union[torch.Tensor, None] = None
+            self.prompt_embeds: Union[PromptEmbeds, None] = None
             # if self.file_items[0].control_tensor is not None:
             # if any have a control tensor, we concatenate them
             if any([x.control_tensor is not None for x in self.file_items]):
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
                         self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
                     else:
                         raise Exception("clip_image_embeds_unconditional is None for some file items")
+            if any([x.prompt_embeds is not None for x in self.file_items]):
+                # find one to use as a base
+                base_prompt_embeds = None
+                for x in self.file_items:
+                    if x.prompt_embeds is not None:
+                        base_prompt_embeds = x.prompt_embeds
+                        break
+                prompt_embeds_list = []
+                for x in self.file_items:
+                    if x.prompt_embeds is None:
+                        prompt_embeds_list.append(base_prompt_embeds)
+                    else:
+                        prompt_embeds_list.append(x.prompt_embeds)
+                self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
         except Exception as e:
             print(e)

View File

@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
 import albumentations as A
 from toolkit.print import print_acc
 from toolkit.accelerator import get_accelerator
+from toolkit.prompt_utils import PromptEmbeds
 from toolkit.train_tools import get_torch_dtype
@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
         self.extra_values: List[float] = dataset_config.extra_values
     # todo allow for loading from sd-scripts style dict
-    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
+    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
         if self.raw_caption is not None:
             # we already loaded it
             pass
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
         if self.dataset_config.num_frames > 1:
             self.load_and_process_video(transform, only_load_latents)
             return
+        # handle get_prompt_embedding
+        if self.is_text_embedding_cached:
+            self.load_prompt_embedding()
         # if we are caching latents, just do that
         if self.is_latent_cached:
             self.get_latent()
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
         self.sd.restore_device_state()
+class TextEmbeddingFileItemDTOMixin:
+    def __init__(self, *args, **kwargs):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__(*args, **kwargs)
+        self.prompt_embeds: Union[PromptEmbeds, None] = None
+        self._text_embedding_path: Union[str, None] = None
+        self.is_text_embedding_cached = False
+        self.text_embedding_load_device = 'cpu'
+        self.text_embedding_space_version = 'sd1'
+        self.text_embedding_version = 1
+    def get_text_embedding_info_dict(self: 'FileItemDTO'):
+        # make sure the caption is loaded here
+        # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
+        if self.caption is None:
+            self.load_caption()
+        # throw an error if [trigger] is in the caption, as we cannot inject it while caching
+        if '[trigger]' in self.caption:
+            raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
+        item = OrderedDict([
+            ("caption", self.caption),
+            ("text_embedding_space_version", self.text_embedding_space_version),
+            ("text_embedding_version", self.text_embedding_version),
+        ])
+        return item
+    def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
+        if self._text_embedding_path is not None and not recalculate:
+            return self._text_embedding_path
+        else:
+            # we store text embeddings in a folder next to the image called _t_e_cache
+            img_dir = os.path.dirname(self.path)
+            te_dir = os.path.join(img_dir, '_t_e_cache')
+            hash_dict = self.get_text_embedding_info_dict()
+            filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
+            # get base64 hash of md5 checksum of hash_dict
+            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+            hash_str = hash_str.replace('=', '')
+            self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors')
+            return self._text_embedding_path
+    def cleanup_text_embedding(self):
+        if self.prompt_embeds is not None:
+            # we are caching on disk, don't keep it in memory
+            self.prompt_embeds = None
+    def load_prompt_embedding(self, device=None):
+        if not self.is_text_embedding_cached:
+            return
+        if self.prompt_embeds is None:
+            # load it from disk
+            self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())
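A worked example of the cache-path derivation above, runnable with only the standard library (the caption and version values are samples):

    import base64
    import hashlib
    import json

    hash_dict = {
        "caption": "a photo of a cat",
        "text_embedding_space_version": "sd1",
        "text_embedding_version": 1,
    }
    hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
    hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
    # md5 is 16 bytes -> 22 urlsafe base64 chars after stripping '=' padding
    print(f"_t_e_cache/cat_{hash_str}.safetensors")

Because the caption and version fields feed the hash, any caption edit or format bump yields a new filename, so stale cache files are simply never matched rather than explicitly invalidated.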
class TextEmbeddingCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
         if hasattr(super(), '__init__'):
             super().__init__(**kwargs)
         self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
-        if self.is_caching_text_embeddings:
-            raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
     def cache_text_embeddings(self: 'AiToolkitDataset'):
         with accelerator.main_process_first():
             print_acc(f"Caching text_embeddings for {self.dataset_path}")
-            # cache all latents to disk
-            to_disk = self.is_caching_latents_to_disk
-            to_memory = self.is_caching_latents_to_memory
+            print_acc(" - Saving text embeddings to disk")
-            # move sd items to cpu except for vae
-            self.sd.set_device_state_preset('cache_latents')
+            did_move = False
             # use tqdm to show progress
             i = 0
-            for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
-                # set latent space version
-                if self.sd.model_config.latent_space_version is not None:
-                    file_item.latent_space_version = self.sd.model_config.latent_space_version
-                elif self.sd.is_xl:
-                    file_item.latent_space_version = 'sdxl'
-                elif self.sd.is_v3:
-                    file_item.latent_space_version = 'sd3'
-                elif self.sd.is_auraflow:
-                    file_item.latent_space_version = 'sdxl'
-                elif self.sd.is_flux:
-                    file_item.latent_space_version = 'flux1'
-                elif self.sd.model_config.is_pixart_sigma:
-                    file_item.latent_space_version = 'sdxl'
-                else:
-                    file_item.latent_space_version = self.sd.model_config.arch
-                file_item.is_caching_to_disk = to_disk
-                file_item.is_caching_to_memory = to_memory
+            for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
+                file_item.text_embedding_space_version = self.sd.model_config.arch
-                file_item.latent_load_device = self.sd.device
-                latent_path = file_item.get_latent_path(recalculate=True)
-                # check if it is saved to disk already
-                if os.path.exists(latent_path):
-                    if to_memory:
-                        # load it into memory
-                        state_dict = load_file(latent_path, device='cpu')
-                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
-                else:
-                    # not saved to disk, calculate
-                    # load the image first
-                    file_item.load_and_process_image(self.transform, only_load_latents=True)
-                    dtype = self.sd.torch_dtype
-                    device = self.sd.device_torch
-                    # add batch dimension
-                    try:
-                        imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
-                        latent = self.sd.encode_images(imgs).squeeze(0)
-                    except Exception as e:
-                        print_acc(f"Error processing image: {file_item.path}")
-                        print_acc(f"Error: {str(e)}")
-                        raise e
-                    # save_latent
-                    if to_disk:
-                        state_dict = OrderedDict([
-                            ('latent', latent.clone().detach().cpu()),
-                        ])
-                        # metadata
-                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
-                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
-                        save_file(state_dict, latent_path, metadata=meta)
-                    if to_memory:
-                        # keep it in memory
-                        file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
-                    del imgs
-                    del latent
-                    del file_item.tensor
-                # flush(garbage_collect=False)
-                file_item.is_latent_cached = True
+                text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
+                # only process if not saved to disk
+                if not os.path.exists(text_embedding_path):
+                    # load if not loaded
+                    if not did_move:
+                        self.sd.set_device_state_preset('cache_text_encoder')
+                        did_move = True
+                    prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
+                    # save it
+                    prompt_embeds.save(text_embedding_path)
+                    del prompt_embeds
+                file_item.is_text_embedding_cached = True
                 i += 1
                 # flush every 100
                 # if i % 100 == 0:
                 #     flush()
-            # restore device state
-            self.sd.restore_device_state()
+            # if did_move:
+            #     self.sd.restore_device_state()
 class CLIPCachingMixin:

View File

@@ -168,6 +168,8 @@ class BaseModel:
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
         self.is_transformer = False
+        self.sample_prompts_cache = None
     # properties for old arch for backwards compatibility
@property
@@ -484,19 +486,23 @@ class BaseModel:
             quad_count=4
         )
-        # encode the prompt ourselves so we can do fun stuff with embeddings
-        if isinstance(self.adapter, CustomAdapter):
-            self.adapter.is_unconditional_run = False
-        conditional_embeds = self.encode_prompt(
-            gen_config.prompt, gen_config.prompt_2, force_all=True)
+        if self.sample_prompts_cache is not None:
+            conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
+            unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
+        else:
+            # encode the prompt ourselves so we can do fun stuff with embeddings
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False
+            conditional_embeds = self.encode_prompt(
+                gen_config.prompt, gen_config.prompt_2, force_all=True)
-        if isinstance(self.adapter, CustomAdapter):
-            self.adapter.is_unconditional_run = True
-        unconditional_embeds = self.encode_prompt(
-            gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
-        )
-        if isinstance(self.adapter, CustomAdapter):
-            self.adapter.is_unconditional_run = False
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = True
+            unconditional_embeds = self.encode_prompt(
+                gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
+            )
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False
         # allow any manipulations to take place to embeddings
         gen_config.post_process_embeddings(

View File

@@ -92,6 +92,56 @@ class PromptEmbeds:
             pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
         return pe
+    def save(self, path: str):
+        """
+        Save the prompt embeds to a file.
+        :param path: The path to save the prompt embeds.
+        """
+        pe = self.clone()
+        state_dict = {}
+        if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
+            for i, text_embed in enumerate(pe.text_embeds):
+                state_dict[f"text_embed_{i}"] = text_embed.cpu()
+        else:
+            state_dict["text_embed"] = pe.text_embeds.cpu()
+        if pe.pooled_embeds is not None:
+            state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
+        if pe.attention_mask is not None:
+            state_dict["attention_mask"] = pe.attention_mask.cpu()
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        save_file(state_dict, path)
+    @classmethod
+    def load(cls, path: str) -> 'PromptEmbeds':
+        """
+        Load the prompt embeds from a file.
+        :param path: The path to load the prompt embeds from.
+        :return: An instance of PromptEmbeds.
+        """
+        state_dict = load_file(path, device='cpu')
+        text_embeds = []
+        pooled_embeds = None
+        attention_mask = None
+        for key in sorted(state_dict.keys()):
+            if key.startswith("text_embed_"):
+                text_embeds.append(state_dict[key])
+            elif key == "text_embed":
+                text_embeds.append(state_dict[key])
+            elif key == "pooled_embed":
+                pooled_embeds = state_dict[key]
+            elif key == "attention_mask":
+                attention_mask = state_dict[key]
+        pe = cls(None)
+        pe.text_embeds = text_embeds
+        if len(text_embeds) == 1:
+            pe.text_embeds = text_embeds[0]
+        if pooled_embeds is not None:
+            pe.pooled_embeds = pooled_embeds
+        if attention_mask is not None:
+            pe.attention_mask = attention_mask
+        return pe
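A round-trip sketch of the two new methods (illustrative; it assumes, as load() itself does, that PromptEmbeds can be constructed with None and assigned a bare tensor, and that clone() tolerates pooled_embeds being unset; the tensor shape is a stand-in for real encoder output):

    import torch
    from toolkit.prompt_utils import PromptEmbeds

    pe = PromptEmbeds(None)
    pe.text_embeds = torch.randn(1, 77, 768)  # stand-in for encoder output
    pe.save('/tmp/_t_e_cache/example.safetensors')  # creates the folder too

    loaded = PromptEmbeds.load('/tmp/_t_e_cache/example.safetensors')
    assert torch.equal(loaded.text_embeds, pe.text_embeds)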
class EncodedPromptPair:
def __init__(

View File

@@ -209,6 +209,8 @@ class StableDiffusion:
         # todo update this based on the model
         self.is_transformer = False
+        self.sample_prompts_cache = None
     # properties for old arch for backwards compatibility
     @property
     def is_xl(self):
@@ -1426,18 +1428,22 @@ class StableDiffusion:
             quad_count=4
         )
-        # encode the prompt ourselves so we can do fun stuff with embeddings
-        if isinstance(self.adapter, CustomAdapter):
-            self.adapter.is_unconditional_run = False
-        conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
+        if self.sample_prompts_cache is not None:
+            conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
+            unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
+        else:
+            # encode the prompt ourselves so we can do fun stuff with embeddings
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False
+            conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
-        if isinstance(self.adapter, CustomAdapter):
-            self.adapter.is_unconditional_run = True
-        unconditional_embeds = self.encode_prompt(
-            gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
-        )
-        if isinstance(self.adapter, CustomAdapter):
-            self.adapter.is_unconditional_run = False
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = True
+            unconditional_embeds = self.encode_prompt(
+                gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
+            )
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False
         # allow any manipulations to take place to embeddings
         gen_config.post_process_embeddings(

toolkit/unloader.py (new file, 63 lines)
View File

@@ -0,0 +1,63 @@
import torch
from toolkit.basic import flush
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel


class FakeTextEncoder(torch.nn.Module):
    def __init__(self, device, dtype):
        super().__init__()
        # register a dummy parameter to avoid errors in some cases
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
        self._device = device
        self._dtype = dtype

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This is a fake text encoder and should not be used for inference."
        )

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    def to(self, *args, **kwargs):
        return self


def unload_text_encoder(model: "BaseModel"):
    # unload the text encoder in a way that will work with all models and will not throw errors.
    # we need to make it appear as a text encoder module without actually having one, so that
    # .to() calls and the like still work.
    if model.text_encoder is not None:
        if isinstance(model.text_encoder, list):
            text_encoder_list = []
            pipe = model.pipeline
            # the pipeline stores text encoders as text_encoder, text_encoder_2, text_encoder_3, etc.
            if hasattr(pipe, "text_encoder"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                pipe.text_encoder = te
            i = 2
            while hasattr(pipe, f"text_encoder_{i}"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                setattr(pipe, f"text_encoder_{i}", te)
                i += 1
            model.text_encoder = text_encoder_list
        else:
            # only has a single text encoder
            model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
    flush()
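A quick sketch of the stub's contract (illustrative; the device/dtype values are arbitrary): device moves are silently ignored, while accidental encoding fails loudly.

    import torch
    from toolkit.unloader import FakeTextEncoder

    fake = FakeTextEncoder(device='cpu', dtype=torch.float32)
    assert fake.to('cuda') is fake   # .to() is a no-op by design
    assert fake.device == 'cpu'      # reported device never changes
    try:
        fake("a prompt")             # forward always raises
    except NotImplementedError:
        pass

The dummy parameter exists so that code iterating a module's parameters (optimizer setup, device-state bookkeeping) does not choke on an empty module.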

View File

@@ -389,22 +389,40 @@ export default function SimpleJob({
                  onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
                />
              </FormGroup>
-              <NumberInput
-                label="EMA Decay"
-                className="pt-2"
-                value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
-                onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
-                placeholder="eg. 0.99"
-                min={0}
-              />
-              <FormGroup label="Unload Text Encoder" className="pt-2">
-                <div className="grid grid-cols-2 gap-2">
-                  <Checkbox
-                    label="Unload TE"
-                    checked={jobConfig.config.process[0].train.unload_text_encoder || false}
-                    onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
-                  />
-                </div>
+              {jobConfig.config.process[0].train.ema_config?.use_ema && (
+                <NumberInput
+                  label="EMA Decay"
+                  className="pt-2"
+                  value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
+                  onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
+                  placeholder="eg. 0.99"
+                  min={0}
+                />
+              )}
+              <FormGroup label="Text Encoder Optimizations" className="pt-2">
+                <Checkbox
+                  label="Unload TE"
+                  checked={jobConfig.config.process[0].train.unload_text_encoder || false}
+                  docKey={'train.unload_text_encoder'}
+                  onChange={(value) => {
+                    setJobConfig(value, 'config.process[0].train.unload_text_encoder')
+                    if (value) {
+                      setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
+                    }
+                  }}
+                />
+                <Checkbox
+                  label="Cache Text Embeddings"
+                  checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
+                  docKey={'train.cache_text_embeddings'}
+                  onChange={(value) => {
+                    setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
+                    if (value) {
+                      setJobConfig(false, 'config.process[0].train.unload_text_encoder')
+                    }
+                  }}
+                />
              </FormGroup>
</div>
<div>
@@ -416,21 +434,27 @@ export default function SimpleJob({
                  onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
                />
              </FormGroup>
-              <NumberInput
-                label="DOP Loss Multiplier"
-                className="pt-2"
-                value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
-                onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
-                placeholder="eg. 1.0"
-                min={0}
-              />
-              <TextInput
-                label="DOP Preservation Class"
-                className="pt-2"
-                value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
-                onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
-                placeholder="eg. woman"
-              />
+              {jobConfig.config.process[0].train.diff_output_preservation && (
+                <>
+                  <NumberInput
+                    label="DOP Loss Multiplier"
+                    className="pt-2"
+                    value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
+                    onChange={value =>
+                      setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
+                    }
+                    placeholder="eg. 1.0"
+                    min={0}
+                  />
+                  <TextInput
+                    label="DOP Preservation Class"
+                    className="pt-2"
+                    value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
+                    onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
+                    placeholder="eg. woman"
+                  />
+                </>
+              )}
            </div>
          </div>
        </Card>
@@ -524,16 +548,14 @@ export default function SimpleJob({
                    checked={dataset.is_reg || false}
                    onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
                  />
-                  {
-                    modelArch?.additionalSections?.includes('datasets.do_i2v') && (
-                      <Checkbox
-                        label="Do I2V"
-                        checked={dataset.do_i2v || false}
-                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
-                        docKey="datasets.do_i2v"
-                      />
-                    )
-                  }
+                  {modelArch?.additionalSections?.includes('datasets.do_i2v') && (
+                    <Checkbox
+                      label="Do I2V"
+                      checked={dataset.do_i2v || false}
+                      onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
+                      docKey="datasets.do_i2v"
+                    />
+                  )}
                </FormGroup>
              </div>
              <div>

View File

@@ -66,6 +66,7 @@ export const defaultJobConfig: JobConfig = {
     weight_decay: 1e-4,
   },
   unload_text_encoder: false,
+  cache_text_embeddings: false,
   lr: 0.0001,
   ema_config: {
     use_ema: false,

View File

@@ -12,12 +12,12 @@ const docs: { [key: string]: ConfigDoc } = {
      </>
    ),
  },
-  'gpuids': {
+  gpuids: {
    title: 'GPU ID',
    description: (
      <>
        This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
        However, you can start multiple jobs in parallel, each using a different GPU.
      </>
    ),
  },
@@ -25,17 +25,19 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Trigger Word',
    description: (
      <>
-        Optional: This will be the word or token used to trigger your concept or character.
-        <br />
-        <br />
-        When using a trigger word,
-        If your captions do not contain the trigger word, it will be added automatically the beginning of the caption. If you do not have
-        captions, the caption will become just the trigger word. If you want to have variable trigger words in your captions to put it in different spots,
-        you can use the <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger word.
-        <br />
-        <br />
-        Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the
-        <code>{'[trigger]'}</code> placeholder in your test prompts as well.
+        Optional: This will be the word or token used to trigger your concept or character.
+        <br />
+        <br />
+        When using a trigger word, if your captions do not contain the trigger word, it will be added automatically
+        to the beginning of the caption. If you do not have captions, the caption will become just the trigger word.
+        If you want to have variable trigger words in your captions to put it in different spots, you can use the{' '}
+        <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger
+        word.
+        <br />
+        <br />
+        Trigger words will not automatically be added to your test prompts, so you will need to either add your
+        trigger word manually or use the <code>{'[trigger]'}</code> placeholder in your test prompts as well.
      </>
    ),
  },
@@ -43,8 +45,9 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Name or Path',
    description: (
      <>
-        The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in
-        diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all in one safetensors checkpoint here.
+        The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
+        folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
+        path to an all in one safetensors checkpoint here.
      </>
    ),
  },
@@ -52,8 +55,8 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Control Dataset',
    description: (
      <>
-        The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs.
-        These images are fed as control/input images during training.
+        The control dataset needs to have files that match the filenames of your training dataset. They should be
+        matching file pairs. These images are fed as control/input images during training.
      </>
    ),
  },
@@ -61,16 +64,19 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Number of Frames',
    description: (
      <>
-        This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
-        If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
-        <br/>
-        <br/>
-        It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
-        So you would want all of your videos trimmed to around 5 seconds for best results.
-        <br/>
-        <br/>
-        Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
-        evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast.
+        This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1
+        for one frame. If your dataset is only videos, frames will be extracted evenly spaced from the videos in the
+        dataset.
+        <br />
+        <br />
+        It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames
+        will result in a 5 second video. So you would want all of your videos trimmed to around 5 seconds for best
+        results.
+        <br />
+        <br />
+        Example: Setting this to 81 and having 2 videos in your dataset, one 2 seconds and one 90 seconds long, will
+        result in 81 evenly spaced frames for each video, making the 2 second video appear slow and the 90 second
+        video appear very fast.
      </>
    ),
  },
@@ -78,9 +84,30 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Do I2V',
    description: (
      <>
-        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
-        to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
-        for the video. If this option is not set, the dataset will be treated as a T2V dataset.
+        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this
+        dataset to be trained as an I2V dataset. This means that the first frame will be extracted from the video and
+        used as the start image for the video. If this option is not set, the dataset will be treated as a T2V dataset.
      </>
    ),
  },
+  'train.unload_text_encoder': {
+    title: 'Unload Text Encoder',
+    description: (
+      <>
+        Unloading the text encoder will cache the trigger word and the sample prompts and unload the text encoder from
+        the GPU. Captions for the dataset will be ignored.
+      </>
+    ),
+  },
+  'train.cache_text_embeddings': {
+    title: 'Cache Text Embeddings',
+    description: (
+      <>
+        <small>(experimental)</small>
+        <br />
+        Caching text embeddings will process and cache all the text embeddings from the text encoder to the disk. The
+        text encoder will then be unloaded from the GPU. This does not work with anything that dynamically changes the
+        prompt, such as trigger words, caption dropout, etc.
+      </>
+    ),
+  },

View File

@@ -110,6 +110,7 @@ export interface TrainConfig {
   ema_config?: EMAConfig;
   dtype: string;
   unload_text_encoder: boolean;
+  cache_text_embeddings: boolean;
   optimizer_params: {
     weight_decay: number;
   };

View File

@@ -1 +1 @@
VERSION = "0.3.18"
VERSION = "0.4.0"