Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be ompimized

This commit is contained in:
Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions

View File

@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
self.extra_values: List[float] = dataset_config.extra_values
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
if self.raw_caption is not None:
# we already loaded it
pass
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents)
return
# handle get_prompt_embedding
if self.is_text_embedding_cached:
self.load_prompt_embedding()
# if we are caching latents, just do that
if self.is_latent_cached:
self.get_latent()
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
self.sd.restore_device_state()
class TextEmbeddingFileItemDTOMixin:
def __init__(self, *args, **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.prompt_embeds: Union[PromptEmbeds, None] = None
self._text_embedding_path: Union[str, None] = None
self.is_text_embedding_cached = False
self.text_embedding_load_device = 'cpu'
self.text_embedding_space_version = 'sd1'
self.text_embedding_version = 1
def get_text_embedding_info_dict(self: 'FileItemDTO'):
# make sure the caption is loaded here
# TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
if self.caption is None:
self.load_caption()
# throw error is [trigger] in caption as we cannot inject it while caching
if '[trigger]' in self.caption:
raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
item = OrderedDict([
("caption", self.caption),
("text_embedding_space_version", self.text_embedding_space_version),
("text_embedding_version", self.text_embedding_version),
])
return item
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
if self._text_embedding_path is not None and not recalculate:
return self._text_embedding_path
else:
# we store text embeddings in a folder in same path as image called _text_embedding_cache
img_dir = os.path.dirname(self.path)
te_dir = os.path.join(img_dir, '_t_e_cache')
hash_dict = self.get_text_embedding_info_dict()
filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
# get base64 hash of md5 checksum of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._text_embedding_path
def cleanup_text_embedding(self):
if self.prompt_embeds is not None:
# we are caching on disk, don't save in memory
self.prompt_embeds = None
def load_prompt_embedding(self, device=None):
if not self.is_text_embedding_cached:
return
if self.prompt_embeds is None:
# load it from disk
self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())
class TextEmbeddingCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
if self.is_caching_text_embeddings:
raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
def cache_text_embeddings(self: 'AiToolkitDataset'):
with accelerator.main_process_first():
print_acc(f"Caching text_embeddings for {self.dataset_path}")
# cache all latents to disk
to_disk = self.is_caching_latents_to_disk
to_memory = self.is_caching_latents_to_memory
print_acc(" - Saving text embeddings to disk")
# move sd items to cpu except for vae
self.sd.set_device_state_preset('cache_latents')
did_move = False
# use tqdm to show progress
i = 0
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
# set latent space version
if self.sd.model_config.latent_space_version is not None:
file_item.latent_space_version = self.sd.model_config.latent_space_version
elif self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
elif self.sd.is_auraflow:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_flux:
file_item.latent_space_version = 'flux1'
elif self.sd.model_config.is_pixart_sigma:
file_item.latent_space_version = 'sdxl'
else:
file_item.latent_space_version = self.sd.model_config.arch
file_item.is_caching_to_disk = to_disk
file_item.is_caching_to_memory = to_memory
for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
file_item.text_embedding_space_version = self.sd.model_config.arch
file_item.latent_load_device = self.sd.device
latent_path = file_item.get_latent_path(recalculate=True)
# check if it is saved to disk already
if os.path.exists(latent_path):
if to_memory:
# load it into memory
state_dict = load_file(latent_path, device='cpu')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
else:
# not saved to disk, calculate
# load the image first
file_item.load_and_process_image(self.transform, only_load_latents=True)
dtype = self.sd.torch_dtype
device = self.sd.device_torch
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
save_file(state_dict, latent_path, metadata=meta)
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
# flush(garbage_collect=False)
file_item.is_latent_cached = True
text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
# only process if not saved to disk
if not os.path.exists(text_embedding_path):
# load if not loaded
if not did_move:
self.sd.set_device_state_preset('cache_text_encoder')
did_move = True
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
# save it
prompt_embeds.save(text_embedding_path)
del prompt_embeds
file_item.is_text_embedding_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()
# if did_move:
# self.sd.restore_device_state()
class CLIPCachingMixin: