mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Shrink text embeds to max token length for LTX-2. Drastically reduces cached text embedding sizes
This commit is contained in:
@@ -1721,8 +1721,6 @@ class LatentCachingFileItemDTOMixin:
|
||||
self.is_caching_to_disk = False
|
||||
self.is_caching_to_memory = False
|
||||
self.latent_load_device = 'cpu'
|
||||
# sd1 or sdxl or others
|
||||
self.latent_space_version = 'sd1'
|
||||
# todo, increment this if we change the latent format to invalidate cache
|
||||
self.latent_version = 1
|
||||
|
||||
@@ -1829,21 +1827,6 @@ class LatentCachingMixin:
|
||||
# use tqdm to show progress
|
||||
i = 0
|
||||
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
|
||||
# set latent space version
|
||||
if self.sd.model_config.latent_space_version is not None:
|
||||
file_item.latent_space_version = self.sd.model_config.latent_space_version
|
||||
elif self.sd.is_xl:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_v3:
|
||||
file_item.latent_space_version = 'sd3'
|
||||
elif self.sd.is_auraflow:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_flux:
|
||||
file_item.latent_space_version = 'flux1'
|
||||
elif self.sd.model_config.is_pixart_sigma:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
else:
|
||||
file_item.latent_space_version = self.sd.model_config.arch
|
||||
file_item.is_caching_to_disk = to_disk
|
||||
file_item.is_caching_to_memory = to_memory
|
||||
file_item.latent_load_device = self.sd.device
|
||||
@@ -1933,7 +1916,6 @@ class TextEmbeddingFileItemDTOMixin:
|
||||
self._text_embedding_path: Union[str, None] = None
|
||||
self.is_text_embedding_cached = False
|
||||
self.text_embedding_load_device = 'cpu'
|
||||
self.text_embedding_space_version = 'sd1'
|
||||
self.text_embedding_version = 1
|
||||
|
||||
def get_text_embedding_info_dict(self: 'FileItemDTO'):
|
||||
@@ -1997,7 +1979,6 @@ class TextEmbeddingCachingMixin:
|
||||
# use tqdm to show progress
|
||||
i = 0
|
||||
for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
|
||||
file_item.text_embedding_space_version = self.sd.model_config.arch
|
||||
file_item.latent_load_device = self.sd.device
|
||||
|
||||
text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
|
||||
|
||||
Reference in New Issue
Block a user