diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a8814f94..adf5dc6f 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -816,6 +816,7 @@ class DatasetConfig: # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False) self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False) + self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False) self.standardize_images: bool = kwargs.get('standardize_images', False) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 54f5a98c..257cd462 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -19,7 +19,7 @@ import albumentations as A from toolkit import image_utils from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config -from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin +from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin, TextEmbeddingCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.print import print_acc from toolkit.accelerator import get_accelerator @@ -378,7 +378,7 @@ class PairedImageDataset(Dataset): return img, prompt, (self.neg_weight, self.pos_weight) -class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset): +class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, TextEmbeddingCachingMixin, BucketsMixin, CaptionMixin, Dataset): def __init__( self, diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 5a6aa239..d804ed98 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1780,6 +1780,8 @@ class TextEmbeddingCachingMixin: if hasattr(super(), '__init__'): super().__init__(**kwargs) self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings + if self.is_caching_text_embeddings: + raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config") def cache_text_embeddings(self: 'AiToolkitDataset'): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index de8865f6..24959662 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -3020,6 +3020,8 @@ class StableDiffusion: active_modules = ['vae'] if device_state_preset in ['cache_clip']: active_modules = ['clip'] + if device_state_preset in ['cache_text_encoder']: + active_modules = ['text_encoder'] if device_state_preset in ['unload']: active_modules = [] if device_state_preset in ['generate']: