mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized
This commit is contained in:
@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
|
||||
import albumentations as A
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.accelerator import get_accelerator
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
|
||||
self.extra_values: List[float] = dataset_config.extra_values
|
||||
|
||||
# todo allow for loading from sd-scripts style dict
|
||||
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
|
||||
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
|
||||
if self.raw_caption is not None:
|
||||
# we already loaded it
|
||||
pass
|
||||
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
|
||||
if self.dataset_config.num_frames > 1:
|
||||
self.load_and_process_video(transform, only_load_latents)
|
||||
return
|
||||
# handle get_prompt_embedding
|
||||
if self.is_text_embedding_cached:
|
||||
self.load_prompt_embedding()
|
||||
# if we are caching latents, just do that
|
||||
if self.is_latent_cached:
|
||||
self.get_latent()
|
||||
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
|
||||
self.sd.restore_device_state()
|
||||
|
||||
|
||||
class TextEmbeddingFileItemDTOMixin:
    """File-item mixin that tracks a prompt embedding cached on disk.

    The cache file lives in a ``_t_e_cache`` folder next to the item's source
    file. Its name is derived from the source file name plus a hash of the
    caption and the text-embedding space/version, so changing any of those
    yields a different cache path.
    """

    def __init__(self, *args, **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        # in-memory embedding; stays None until load_prompt_embedding() fills it
        self.prompt_embeds: Union[PromptEmbeds, None] = None
        # memoized result of get_text_embedding_path()
        self._text_embedding_path: Union[str, None] = None
        self.is_text_embedding_cached = False
        self.text_embedding_load_device = 'cpu'
        self.text_embedding_space_version = 'sd1'
        self.text_embedding_version = 1

    def get_text_embedding_info_dict(self: 'FileItemDTO'):
        """Return the ordered fields that identify this item's text embedding.

        The caption is loaded first if it has not been loaded yet.

        Returns:
            OrderedDict with ``caption``, ``text_embedding_space_version`` and
            ``text_embedding_version``.

        Raises:
            Exception: if no caption is available after loading, or if the
                caption contains ``[trigger]`` (it cannot be injected while
                caching).
        """
        # make sure the caption is loaded here
        # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
        if self.caption is None:
            self.load_caption()
        if self.caption is None:
            # without this guard the membership test below fails with an
            # unhelpful "argument of type 'NoneType' is not iterable"
            raise Exception(
                f"Error: no caption available for {getattr(self, 'path', None)}; "
                "cannot cache text embeddings."
            )
        # throw error if [trigger] in caption as we cannot inject it while caching
        if '[trigger]' in self.caption:
            raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
        return OrderedDict([
            ("caption", self.caption),
            ("text_embedding_space_version", self.text_embedding_space_version),
            ("text_embedding_version", self.text_embedding_version),
        ])

    def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
        """Return the cache file path for this item's text embedding.

        Args:
            recalculate: when True, ignore the memoized path and rebuild it
                from the current caption / embedding version fields.

        Returns:
            ``<dir of self.path>/_t_e_cache/<file stem>_<hash>.safetensors``
        """
        if self._text_embedding_path is not None and not recalculate:
            return self._text_embedding_path

        # we store text embeddings in a folder called _t_e_cache in the same
        # directory as the source file
        img_dir = os.path.dirname(self.path)
        te_dir = os.path.join(img_dir, '_t_e_cache')
        filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]

        # url-safe base64 of the md5 digest of the JSON-serialized info dict
        hash_dict = self.get_text_embedding_info_dict()
        hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
        digest = hashlib.md5(hash_input).digest()
        hash_str = base64.urlsafe_b64encode(digest).decode('ascii')
        # '=' only ever appears as trailing padding in base64 output
        hash_str = hash_str.rstrip('=')

        self._text_embedding_path = os.path.join(
            te_dir, f'{filename_no_ext}_{hash_str}.safetensors'
        )
        return self._text_embedding_path

    def cleanup_text_embedding(self):
        """Drop the in-memory prompt embedding, if one is held."""
        if self.prompt_embeds is not None:
            # we are caching on disk, don't save in memory
            self.prompt_embeds = None

    def load_prompt_embedding(self, device=None):
        """Load the cached prompt embedding from disk if not already in memory.

        Does nothing unless ``is_text_embedding_cached`` is set.

        Args:
            device: accepted for interface compatibility; currently unused.
        """
        if not self.is_text_embedding_cached:
            return
        if self.prompt_embeds is None:
            # load it from disk
            self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())
|
||||
|
||||
class TextEmbeddingCachingMixin:
|
||||
def __init__(self: 'AiToolkitDataset', **kwargs):
|
||||
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(**kwargs)
|
||||
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
|
||||
if self.is_caching_text_embeddings:
|
||||
raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
|
||||
|
||||
def cache_text_embeddings(self: 'AiToolkitDataset'):
|
||||
|
||||
with accelerator.main_process_first():
|
||||
print_acc(f"Caching text_embeddings for {self.dataset_path}")
|
||||
# cache all latents to disk
|
||||
to_disk = self.is_caching_latents_to_disk
|
||||
to_memory = self.is_caching_latents_to_memory
|
||||
print_acc(" - Saving text embeddings to disk")
|
||||
# move sd items to cpu except for vae
|
||||
self.sd.set_device_state_preset('cache_latents')
|
||||
|
||||
did_move = False
|
||||
|
||||
# use tqdm to show progress
|
||||
i = 0
|
||||
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
|
||||
# set latent space version
|
||||
if self.sd.model_config.latent_space_version is not None:
|
||||
file_item.latent_space_version = self.sd.model_config.latent_space_version
|
||||
elif self.sd.is_xl:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_v3:
|
||||
file_item.latent_space_version = 'sd3'
|
||||
elif self.sd.is_auraflow:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_flux:
|
||||
file_item.latent_space_version = 'flux1'
|
||||
elif self.sd.model_config.is_pixart_sigma:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
else:
|
||||
file_item.latent_space_version = self.sd.model_config.arch
|
||||
file_item.is_caching_to_disk = to_disk
|
||||
file_item.is_caching_to_memory = to_memory
|
||||
for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
|
||||
file_item.text_embedding_space_version = self.sd.model_config.arch
|
||||
file_item.latent_load_device = self.sd.device
|
||||
|
||||
latent_path = file_item.get_latent_path(recalculate=True)
|
||||
# check if it is saved to disk already
|
||||
if os.path.exists(latent_path):
|
||||
if to_memory:
|
||||
# load it into memory
|
||||
state_dict = load_file(latent_path, device='cpu')
|
||||
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
|
||||
else:
|
||||
# not saved to disk, calculate
|
||||
# load the image first
|
||||
file_item.load_and_process_image(self.transform, only_load_latents=True)
|
||||
dtype = self.sd.torch_dtype
|
||||
device = self.sd.device_torch
|
||||
# add batch dimension
|
||||
try:
|
||||
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
|
||||
latent = self.sd.encode_images(imgs).squeeze(0)
|
||||
except Exception as e:
|
||||
print_acc(f"Error processing image: {file_item.path}")
|
||||
print_acc(f"Error: {str(e)}")
|
||||
raise e
|
||||
# save_latent
|
||||
if to_disk:
|
||||
state_dict = OrderedDict([
|
||||
('latent', latent.clone().detach().cpu()),
|
||||
])
|
||||
# metadata
|
||||
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
|
||||
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
|
||||
save_file(state_dict, latent_path, metadata=meta)
|
||||
|
||||
if to_memory:
|
||||
# keep it in memory
|
||||
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
|
||||
|
||||
del imgs
|
||||
del latent
|
||||
del file_item.tensor
|
||||
|
||||
# flush(garbage_collect=False)
|
||||
file_item.is_latent_cached = True
|
||||
text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
|
||||
# only process if not saved to disk
|
||||
if not os.path.exists(text_embedding_path):
|
||||
# load if not loaded
|
||||
if not did_move:
|
||||
self.sd.set_device_state_preset('cache_text_encoder')
|
||||
did_move = True
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
|
||||
# save it
|
||||
prompt_embeds.save(text_embedding_path)
|
||||
del prompt_embeds
|
||||
file_item.is_text_embedding_cached = True
|
||||
i += 1
|
||||
# flush every 100
|
||||
# if i % 100 == 0:
|
||||
# flush()
|
||||
|
||||
# restore device state
|
||||
self.sd.restore_device_state()
|
||||
# if did_move:
|
||||
# self.sd.restore_device_state()
|
||||
|
||||
|
||||
class CLIPCachingMixin:
|
||||
|
||||
Reference in New Issue
Block a user