mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-09 12:39:49 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized.
This commit is contained in:
@@ -13,8 +13,8 @@ from toolkit import image_utils
|
||||
from toolkit.basic import get_quick_signature_string
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
|
||||
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -32,6 +32,7 @@ def print_once(msg):
|
||||
|
||||
class FileItemDTO(
|
||||
LatentCachingFileItemDTOMixin,
|
||||
TextEmbeddingFileItemDTOMixin,
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
@@ -124,6 +125,7 @@ class FileItemDTO(
|
||||
def cleanup(self):
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_text_embedding()
|
||||
self.cleanup_control()
|
||||
self.cleanup_inpaint()
|
||||
self.cleanup_clip_image()
|
||||
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
|
||||
try:
|
||||
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
|
||||
is_latents_cached = self.file_items[0].is_latent_cached
|
||||
is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.prompt_embeds: Union[PromptEmbeds, None] = None
|
||||
# if self.file_items[0].control_tensor is not None:
|
||||
# if any have a control tensor, we concatenate them
|
||||
if any([x.control_tensor is not None for x in self.file_items]):
|
||||
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
|
||||
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
|
||||
else:
|
||||
raise Exception("clip_image_embeds_unconditional is None for some file items")
|
||||
|
||||
if any([x.prompt_embeds is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_prompt_embeds = None
|
||||
for x in self.file_items:
|
||||
if x.prompt_embeds is not None:
|
||||
base_prompt_embeds = x.prompt_embeds
|
||||
break
|
||||
prompt_embeds_list = []
|
||||
for x in self.file_items:
|
||||
if x.prompt_embeds is None:
|
||||
prompt_embeds_list.append(base_prompt_embeds)
|
||||
else:
|
||||
prompt_embeds_list.append(x.prompt_embeds)
|
||||
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
Reference in New Issue
Block a user