Some work on caching text embeddings

Jaret Burkett
2025-07-26 09:22:04 -06:00
parent 0d89c44624
commit 77dc38a574
4 changed files with 7 additions and 2 deletions

View File

@@ -816,6 +816,7 @@ class DatasetConfig:
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
         self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
         self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
+        self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
         self.standardize_images: bool = kwargs.get('standardize_images', False)
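
A quick sketch of how the new flag would be switched on, assuming DatasetConfig is built straight from the dataset's keyword arguments as the hunk above suggests (the folder_path field is illustrative, not part of this diff):

```python
# Hypothetical usage of the new option; it defaults to False when omitted
# because DatasetConfig reads it via kwargs.get('cache_text_embeddings', False).
from toolkit.config_modules import DatasetConfig

dataset_kwargs = {
    'folder_path': '/path/to/images',    # illustrative field, not from this diff
    'cache_latents_to_disk': True,
    'cache_text_embeddings': True,       # option added in this commit
}
config = DatasetConfig(**dataset_kwargs)
print(config.cache_text_embeddings)      # True
```

Note that the guard added further down in this commit raises as soon as the flag is enabled, so for now the option is a placeholder.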

View File

@@ -19,7 +19,7 @@ import albumentations as A
 from toolkit import image_utils
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin, TextEmbeddingCachingMixin
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.print import print_acc
 from toolkit.accelerator import get_accelerator
@@ -378,7 +378,7 @@ class PairedImageDataset(Dataset):
         return img, prompt, (self.neg_weight, self.pos_weight)
-class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, TextEmbeddingCachingMixin, BucketsMixin, CaptionMixin, Dataset):
     def __init__(
             self,
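
Adding TextEmbeddingCachingMixin to the bases relies on the same cooperative-init pattern the other caching mixins use: each mixin does its own setup and then forwards **kwargs up the MRO, as the TextEmbeddingCachingMixin hunk below shows. A self-contained sketch of that pattern (class names, attribute names, and the kwargs.pop handling here are illustrative, not taken from the repo):

```python
# Minimal illustration of cooperative mixin initialization via the MRO.
# Each mixin consumes its own keyword, then defers to the next class in line.
class LatentCachingDemoMixin:
    def __init__(self, **kwargs):
        self.is_caching_latents = kwargs.pop('cache_latents', False)
        if hasattr(super(), '__init__'):   # mirrors the guard used in the repo's mixins
            super().__init__(**kwargs)

class TextEmbeddingCachingDemoMixin:
    def __init__(self, **kwargs):
        self.is_caching_text_embeddings = kwargs.pop('cache_text_embeddings', False)
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)

class DemoDataset(LatentCachingDemoMixin, TextEmbeddingCachingDemoMixin):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)         # walks the MRO left to right

ds = DemoDataset(cache_latents=True, cache_text_embeddings=True)
print(ds.is_caching_latents, ds.is_caching_text_embeddings)  # True True
```

Because every __init__ in the chain forwards what it does not consume, a new mixin can be slotted into the base list without the other classes needing to know about it.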

View File

@@ -1780,6 +1780,8 @@ class TextEmbeddingCachingMixin:
         if hasattr(super(), '__init__'):
             super().__init__(**kwargs)
         self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
+        if self.is_caching_text_embeddings:
+            raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
     def cache_text_embeddings(self: 'AiToolkitDataset'):
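
cache_text_embeddings itself is still a stub, and enabling the flag trips the exception above. Purely as an illustration of what the feature could eventually do, mirroring how latent caches typically key tensors to disk; none of the names or the file layout below come from this commit:

```python
# Hypothetical sketch: cache one caption's embedding on disk, keyed by a hash
# of the caption, and reuse it on later runs instead of re-running the text encoder.
import hashlib
import os
import torch

def cache_text_embedding(caption: str, encode_fn, cache_dir: str) -> torch.Tensor:
    os.makedirs(cache_dir, exist_ok=True)
    key = hashlib.sha256(caption.encode('utf-8')).hexdigest()
    path = os.path.join(cache_dir, f'{key}.pt')
    if os.path.exists(path):
        return torch.load(path, map_location='cpu')   # cache hit: skip the text encoder
    embedding = encode_fn(caption).detach().cpu()     # cache miss: encode once and store
    torch.save(embedding, path)
    return embedding
```

In practice the cache key would also need to cover anything that changes the embedding (tokenizer settings, max token length, which text encoders are active), which is part of what makes this more than a one-line change.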

View File

@@ -3020,6 +3020,8 @@ class StableDiffusion:
             active_modules = ['vae']
         if device_state_preset in ['cache_clip']:
             active_modules = ['clip']
+        if device_state_preset in ['cache_text_encoder']:
+            active_modules = ['text_encoder']
         if device_state_preset in ['unload']:
             active_modules = []
         if device_state_preset in ['generate']:
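
The new 'cache_text_encoder' preset follows the existing shape of this method: each preset names the sub-modules that should stay on the compute device while everything else can be offloaded. A rough sketch of that idea, where the module attribute names and the offloading logic are assumptions rather than code from StableDiffusion:

```python
# Hypothetical sketch of preset -> active-module handling.
# Modules named by the preset go to the compute device; the rest are parked on CPU.
PRESET_ACTIVE_MODULES = {
    'cache_clip': ['clip'],
    'cache_text_encoder': ['text_encoder'],  # preset added in this commit
    'unload': [],
    # the real method handles more presets than are visible in the hunk above
}

def apply_device_state_preset(model, preset: str, device: str = 'cuda'):
    active = PRESET_ACTIVE_MODULES.get(preset, [])
    for name in ('vae', 'clip', 'text_encoder', 'unet'):  # assumed attribute names
        module = getattr(model, name, None)
        if module is None:
            continue
        module.to(device if name in active else 'cpu')    # torch.nn.Module.to
```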