Added support for caching text embeddings. This is initial support and will probably fail for some models; it still needs to be optimized.

Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions

View File

@@ -482,6 +482,8 @@ class TrainConfig:
# will cache a blank prompt or the trigger word, and unload the text encoder to cpu
# will make training faster and use less vram
self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
# will toggle all datasets to cache text embeddings
self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
# for swapping which parameters are trained during training
self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
# 0.1 means 10% of the parameters are active at a time; lower uses less vram, higher uses more
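A minimal sketch of toggling the new flag through the kwargs path shown above (the surrounding job setup is assumed; the key name matches the kwargs.get call in the diff):

# hedged sketch: enabling text-embedding caching for every dataset via TrainConfig kwargs
train_kwargs = {
    'cache_text_embeddings': True,  # toggles all datasets to cache text embeddings
}
# train_config = TrainConfig(**train_kwargs)  # assumes the constructor shown above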
@@ -1189,6 +1191,7 @@ def validate_configs(
train_config: TrainConfig,
model_config: ModelConfig,
save_config: SaveConfig,
dataset_configs: List[DatasetConfig]
):
if model_config.is_flux:
if save_config.save_format != 'diffusers':
@@ -1200,3 +1203,18 @@ def validate_configs(
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
# see if any datasets are caching text embeddings
is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
if is_caching_text_embeddings:
# check if they are doing differential output preservation
if train_config.diff_output_preservation:
raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")
# make sure they are all cached
for dataset in dataset_configs:
if not dataset.cache_text_embeddings:
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")

View File

@@ -558,6 +558,8 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
self.cache_latents_all_latents()
if self.is_caching_clip_vision_to_disk:
self.cache_clip_vision_to_disk()
if self.is_caching_text_embeddings:
self.cache_text_embeddings()
if self.is_generating_controls:
# always do this last
self.setup_controls()

View File

@@ -13,8 +13,8 @@ from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -32,6 +32,7 @@ def print_once(msg):
class FileItemDTO(
LatentCachingFileItemDTOMixin,
TextEmbeddingFileItemDTOMixin,
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
@@ -124,6 +125,7 @@ class FileItemDTO(
def cleanup(self):
self.tensor = None
self.cleanup_latent()
self.cleanup_text_embedding()
self.cleanup_control()
self.cleanup_inpaint()
self.cleanup_clip_image()
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
try:
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached
is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
self.prompt_embeds: Union[PromptEmbeds, None] = None
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
if any([x.control_tensor is not None for x in self.file_items]):
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
else:
raise Exception("clip_image_embeds_unconditional is None for some file items")
if any([x.prompt_embeds is not None for x in self.file_items]):
# find one to use as a base
base_prompt_embeds = None
for x in self.file_items:
if x.prompt_embeds is not None:
base_prompt_embeds = x.prompt_embeds
break
prompt_embeds_list = []
for x in self.file_items:
if x.prompt_embeds is None:
prompt_embeds_list.append(base_prompt_embeds)
else:
prompt_embeds_list.append(x.prompt_embeds)
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
except Exception as e:
print(e)
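A minimal sketch of the fallback used above, assuming each item's embedding is a [1, seq, dim] tensor and that concatenation happens along the batch dimension (the real concat_prompt_embeds presumably also handles pooled embeds and attention masks):

import torch

def collate_prompt_embeds(per_item_embeds):
    # reuse the first available embedding for items that have none,
    # then concatenate along the batch dimension (caller guarantees at least one is set)
    base = next(e for e in per_item_embeds if e is not None)
    filled = [e if e is not None else base for e in per_item_embeds]
    return torch.cat(filled, dim=0)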

View File

@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
self.extra_values: List[float] = dataset_config.extra_values
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
if self.raw_caption is not None:
# we already loaded it
pass
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents)
return
# if the text embedding is cached, load the prompt embedding from disk
if self.is_text_embedding_cached:
self.load_prompt_embedding()
# if we are caching latents, just do that
if self.is_latent_cached:
self.get_latent()
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
self.sd.restore_device_state()
class TextEmbeddingFileItemDTOMixin:
def __init__(self, *args, **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.prompt_embeds: Union[PromptEmbeds, None] = None
self._text_embedding_path: Union[str, None] = None
self.is_text_embedding_cached = False
self.text_embedding_load_device = 'cpu'
self.text_embedding_space_version = 'sd1'
self.text_embedding_version = 1
def get_text_embedding_info_dict(self: 'FileItemDTO'):
# make sure the caption is loaded here
# TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
if self.caption is None:
self.load_caption()
# throw an error if [trigger] is in the caption, as we cannot inject it while caching
if '[trigger]' in self.caption:
raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
item = OrderedDict([
("caption", self.caption),
("text_embedding_space_version", self.text_embedding_space_version),
("text_embedding_version", self.text_embedding_version),
])
return item
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
if self._text_embedding_path is not None and not recalculate:
return self._text_embedding_path
else:
# text embeddings are stored in a _t_e_cache folder alongside the image
img_dir = os.path.dirname(self.path)
te_dir = os.path.join(img_dir, '_t_e_cache')
hash_dict = self.get_text_embedding_info_dict()
filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
# get base64 hash of md5 checksum of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._text_embedding_path
def cleanup_text_embedding(self):
if self.prompt_embeds is not None:
# we are caching on disk, don't save in memory
self.prompt_embeds = None
def load_prompt_embedding(self, device=None):
if not self.is_text_embedding_cached:
return
if self.prompt_embeds is None:
# load it from disk
self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())
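The cache filename above is derived from an md5 digest of the info dict; a standalone sketch of that scheme, assuming only the caption and version fields shown in get_text_embedding_info_dict:

import base64
import hashlib
import json
import os
from collections import OrderedDict

def text_embedding_cache_path(image_path: str, caption: str,
                              space_version: str = 'sd1', version: int = 1) -> str:
    # mirror get_text_embedding_info_dict, then hash it into a url-safe filename suffix
    info = OrderedDict([
        ("caption", caption),
        ("text_embedding_space_version", space_version),
        ("text_embedding_version", version),
    ])
    digest = hashlib.md5(json.dumps(info, sort_keys=True).encode('utf-8')).digest()
    hash_str = base64.urlsafe_b64encode(digest).decode('ascii').replace('=', '')
    name = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(os.path.dirname(image_path), '_t_e_cache', f'{name}_{hash_str}.safetensors')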
class TextEmbeddingCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
if self.is_caching_text_embeddings:
raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
def cache_text_embeddings(self: 'AiToolkitDataset'):
with accelerator.main_process_first():
print_acc(f"Caching text_embeddings for {self.dataset_path}")
# cache all latents to disk
to_disk = self.is_caching_latents_to_disk
to_memory = self.is_caching_latents_to_memory
print_acc(" - Saving text embeddings to disk")
# move sd items to cpu except for vae
self.sd.set_device_state_preset('cache_latents')
did_move = False
# use tqdm to show progress
i = 0
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
# set latent space version
if self.sd.model_config.latent_space_version is not None:
file_item.latent_space_version = self.sd.model_config.latent_space_version
elif self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
elif self.sd.is_auraflow:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_flux:
file_item.latent_space_version = 'flux1'
elif self.sd.model_config.is_pixart_sigma:
file_item.latent_space_version = 'sdxl'
else:
file_item.latent_space_version = self.sd.model_config.arch
file_item.is_caching_to_disk = to_disk
file_item.is_caching_to_memory = to_memory
for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
file_item.text_embedding_space_version = self.sd.model_config.arch
file_item.latent_load_device = self.sd.device
latent_path = file_item.get_latent_path(recalculate=True)
# check if it is saved to disk already
if os.path.exists(latent_path):
if to_memory:
# load it into memory
state_dict = load_file(latent_path, device='cpu')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
else:
# not saved to disk, calculate
# load the image first
file_item.load_and_process_image(self.transform, only_load_latents=True)
dtype = self.sd.torch_dtype
device = self.sd.device_torch
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
save_file(state_dict, latent_path, metadata=meta)
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
# flush(garbage_collect=False)
file_item.is_latent_cached = True
text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
# only process if not saved to disk
if not os.path.exists(text_embedding_path):
# load if not loaded
if not did_move:
self.sd.set_device_state_preset('cache_text_encoder')
did_move = True
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
# save it
prompt_embeds.save(text_embedding_path)
del prompt_embeds
file_item.is_text_embedding_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()
# if did_move:
# self.sd.restore_device_state()
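Distilled from the cache_text_embeddings hunk above, the per-item flow is roughly the following (a sketch; progress reporting and the accelerator guard are omitted):

import os

def cache_text_embeddings_sketch(dataset):
    did_move = False
    for file_item in dataset.file_list:
        file_item.text_embedding_space_version = dataset.sd.model_config.arch
        path = file_item.get_text_embedding_path(recalculate=True)
        if not os.path.exists(path):
            if not did_move:
                # only pull the text encoder onto the device once something actually needs encoding
                dataset.sd.set_device_state_preset('cache_text_encoder')
                did_move = True
            prompt_embeds = dataset.sd.encode_prompt(file_item.caption)
            prompt_embeds.save(path)
            del prompt_embeds
        file_item.is_text_embedding_cached = True
    dataset.sd.restore_device_state()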
class CLIPCachingMixin:

View File

@@ -168,6 +168,8 @@ class BaseModel:
self._after_sample_img_hooks = []
self._status_update_hooks = []
self.is_transformer = False
self.sample_prompts_cache = None
# properties for old arch for backwards compatibility
@property
@@ -484,19 +486,23 @@ class BaseModel:
quad_count=4
)
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(
gen_config.prompt, gen_config.prompt_2, force_all=True)
if self.sample_prompts_cache is not None:
conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
else:
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(
gen_config.prompt, gen_config.prompt_2, force_all=True)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
# allow any manipulations to take place to embeddings
gen_config.post_process_embeddings(
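The sampler above expects sample_prompts_cache to be a list of dicts keyed 'conditional' and 'unconditional', one entry per sample prompt. A hypothetical sketch of how a trainer might populate it before the text encoder is unloaded (the population code is not part of this hunk; build_sample_prompts_cache and sample_gen_configs are illustrative names):

def build_sample_prompts_cache(model, sample_gen_configs):
    # hypothetical helper: precompute sampling embeds so sampling works without a text encoder
    cache = []
    for cfg in sample_gen_configs:  # one GenerateImageConfig-like entry per sample prompt
        cache.append({
            'conditional': model.encode_prompt(cfg.prompt, cfg.prompt_2, force_all=True),
            'unconditional': model.encode_prompt(cfg.negative_prompt, cfg.negative_prompt_2, force_all=True),
        })
    model.sample_prompts_cache = cache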

View File

@@ -92,6 +92,56 @@ class PromptEmbeds:
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
return pe
def save(self, path: str):
"""
Save the prompt embeds to a file.
:param path: The path to save the prompt embeds.
"""
pe = self.clone()
state_dict = {}
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
for i, text_embed in enumerate(pe.text_embeds):
state_dict[f"text_embed_{i}"] = text_embed.cpu()
else:
state_dict["text_embed"] = pe.text_embeds.cpu()
if pe.pooled_embeds is not None:
state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
if pe.attention_mask is not None:
state_dict["attention_mask"] = pe.attention_mask.cpu()
os.makedirs(os.path.dirname(path), exist_ok=True)
save_file(state_dict, path)
@classmethod
def load(cls, path: str) -> 'PromptEmbeds':
"""
Load the prompt embeds from a file.
:param path: The path to load the prompt embeds from.
:return: An instance of PromptEmbeds.
"""
state_dict = load_file(path, device='cpu')
text_embeds = []
pooled_embeds = None
attention_mask = None
for key in sorted(state_dict.keys()):
if key.startswith("text_embed_"):
text_embeds.append(state_dict[key])
elif key == "text_embed":
text_embeds.append(state_dict[key])
elif key == "pooled_embed":
pooled_embeds = state_dict[key]
elif key == "attention_mask":
attention_mask = state_dict[key]
pe = cls(None)
pe.text_embeds = text_embeds
if len(text_embeds) == 1:
pe.text_embeds = text_embeds[0]
if pooled_embeds is not None:
pe.pooled_embeds = pooled_embeds
if attention_mask is not None:
pe.attention_mask = attention_mask
return pe
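A small round-trip sketch of the new save/load pair (pe is assumed to be an existing PromptEmbeds instance; the path is illustrative):

# hedged usage sketch: persist embeddings once, then reload them without a text encoder
pe.save('/tmp/_t_e_cache/example.safetensors')
restored = PromptEmbeds.load('/tmp/_t_e_cache/example.safetensors')
assert restored.text_embeds is not None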
class EncodedPromptPair:
def __init__(

View File

@@ -209,6 +209,8 @@ class StableDiffusion:
# todo update this based on the model
self.is_transformer = False
self.sample_prompts_cache = None
# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -1426,18 +1428,22 @@ class StableDiffusion:
quad_count=4
)
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
if self.sample_prompts_cache is not None:
conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
else:
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
# allow any manipulations to take place to embeddings
gen_config.post_process_embeddings(

toolkit/unloader.py (new file, 63 lines)
View File

@@ -0,0 +1,63 @@
import torch
from toolkit.basic import flush
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
class FakeTextEncoder(torch.nn.Module):
def __init__(self, device, dtype):
super().__init__()
# register a dummy parameter to avoid errors in some cases
self.dummy_param = torch.nn.Parameter(torch.zeros(1))
self._device = device
self._dtype = dtype
def forward(self, *args, **kwargs):
raise NotImplementedError(
"This is a fake text encoder and should not be used for inference."
)
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(self, *args, **kwargs):
return self
def unload_text_encoder(model: "BaseModel"):
# unload the text encoder in a way that works with all models without throwing errors.
# we make it appear as if a text encoder module is still present without actually having one,
# so that .to() calls and the like keep working.
if model.text_encoder is not None:
if isinstance(model.text_encoder, list):
text_encoder_list = []
pipe = model.pipeline
# the pipeline stores text encoders like text_encoder, text_encoder_2, text_encoder_3, etc.
if hasattr(pipe, "text_encoder"):
te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
text_encoder_list.append(te)
pipe.text_encoder = te
i = 2
while hasattr(pipe, f"text_encoder_{i}"):
te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
text_encoder_list.append(te)
setattr(pipe, f"text_encoder_{i}", te)
i += 1
model.text_encoder = text_encoder_list
else:
# only has a single text encoder
model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
flush()
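A hedged usage sketch of the new helper (model is assumed to be an already-loaded BaseModel or StableDiffusion instance whose prompts have all been cached):

from toolkit.unloader import unload_text_encoder

# swap the real encoders for fakes so .to()/device bookkeeping keeps working,
# while any accidental forward pass raises NotImplementedError
unload_text_encoder(model)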