Added support for caching text embeddings. This is initial support and will probably fail for some models; it still needs to be optimized.

Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions


@@ -13,8 +13,8 @@ from toolkit import image_utils
 from toolkit.basic import get_quick_signature_string
 from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
     ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
-    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
+    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
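
Note: the new TextEmbeddingFileItemDTOMixin is imported here but defined in toolkit/dataloader_mixins.py, which is not part of this file view. Inferred from the call sites in the hunks below (is_text_embedding_cached, prompt_embeds, cleanup_text_embedding), its minimal shape is roughly the sketch that follows; the attribute defaults and everything else about the real cache-loading path are assumptions, not the actual implementation:

class TextEmbeddingFileItemDTOMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # assumed default: True only once a cached text embedding exists for this item
        self.is_text_embedding_cached: bool = False
        # assumed to hold a PromptEmbeds instance once the cached embedding is loaded
        self.prompt_embeds = None

    def cleanup_text_embedding(self):
        # release the cached embedding so its tensors can be freed
        self.prompt_embeds = None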
@@ -32,6 +32,7 @@ def print_once(msg):
 class FileItemDTO(
     LatentCachingFileItemDTOMixin,
+    TextEmbeddingFileItemDTOMixin,
     CaptionProcessingDTOMixin,
     ImageProcessingDTOMixin,
     ControlFileItemDTOMixin,
@@ -124,6 +125,7 @@ class FileItemDTO(
     def cleanup(self):
         self.tensor = None
         self.cleanup_latent()
+        self.cleanup_text_embedding()
         self.cleanup_control()
         self.cleanup_inpaint()
         self.cleanup_clip_image()
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
         try:
             self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
             is_latents_cached = self.file_items[0].is_latent_cached
+            is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
             self.tensor: Union[torch.Tensor, None] = None
             self.latents: Union[torch.Tensor, None] = None
             self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
             if is_latents_cached:
                 self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
             self.control_tensor: Union[torch.Tensor, None] = None
+            self.prompt_embeds: Union[PromptEmbeds, None] = None
             # if self.file_items[0].control_tensor is not None:
             # if any have a control tensor, we concatenate them
             if any([x.control_tensor is not None for x in self.file_items]):
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
                         self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
                     else:
                         raise Exception("clip_image_embeds_unconditional is None for some file items")
+            if any([x.prompt_embeds is not None for x in self.file_items]):
+                # find one to use as a base
+                base_prompt_embeds = None
+                for x in self.file_items:
+                    if x.prompt_embeds is not None:
+                        base_prompt_embeds = x.prompt_embeds
+                        break
+                prompt_embeds_list = []
+                for x in self.file_items:
+                    if x.prompt_embeds is None:
+                        prompt_embeds_list.append(base_prompt_embeds)
+                    else:
+                        prompt_embeds_list.append(x.prompt_embeds)
+                self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
         except Exception as e:
             print(e)
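
Note on the hunk above: any file item missing a cached embedding is padded with a copy of another item's embeds so the batch can still be concatenated; the padded item effectively trains against a different caption, which fits the commit message's caveat that this is initial support. The implementation of concat_prompt_embeds lives in toolkit/prompt_utils.py and is not shown in this view; below is a rough stand-alone sketch, assuming PromptEmbeds carries a text_embeds tensor and an optional pooled_embeds tensor (the real class may hold more state):

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class PromptEmbedsSketch:
    # hypothetical stand-in for toolkit.prompt_utils.PromptEmbeds
    text_embeds: torch.Tensor
    pooled_embeds: Optional[torch.Tensor] = None


def concat_prompt_embeds_sketch(items: List[PromptEmbedsSketch]) -> PromptEmbedsSketch:
    # batch along dim 0; every item must share sequence length and hidden size,
    # which is why items without a cached embedding need a stand-in above
    text = torch.cat([it.text_embeds for it in items], dim=0)
    pooled = None
    if items[0].pooled_embeds is not None:
        pooled = torch.cat([it.pooled_embeds for it in items], dim=0)
    # e.g. two items with [1, 77, 768] text embeds produce a [2, 77, 768] batch
    return PromptEmbedsSketch(text, pooled)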