Commit to https://github.com/ostris/ai-toolkit.git: Some work on caching text embeddings
@@ -816,6 +816,7 @@ class DatasetConfig:
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
         self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
         self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
+        self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)

         self.standardize_images: bool = kwargs.get('standardize_images', False)
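For context, a minimal usage sketch of the new flag (not from this commit; folder_path is assumed from the existing config surface). Since the flag is read via kwargs.get with a False default, existing configs are unaffected:

    # Hypothetical sketch, assuming the kwargs interface shown above.
    # Enabling cache_text_embeddings currently raises; see the
    # TextEmbeddingCachingMixin hunk below.
    from toolkit.config_modules import DatasetConfig

    cfg = DatasetConfig(
        folder_path='/path/to/dataset',   # assumed existing option
        cache_latents_to_disk=True,
        cache_text_embeddings=False,      # new flag from this commit
    )
    assert cfg.cache_text_embeddings is False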
@@ -19,7 +19,7 @@ import albumentations as A
 from toolkit import image_utils
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin, TextEmbeddingCachingMixin
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.print import print_acc
 from toolkit.accelerator import get_accelerator
@@ -378,7 +378,7 @@ class PairedImageDataset(Dataset):
         return img, prompt, (self.neg_weight, self.pos_weight)


-class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, TextEmbeddingCachingMixin, BucketsMixin, CaptionMixin, Dataset):
     def __init__(
             self,
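Slotting TextEmbeddingCachingMixin into the base-class list works because every mixin in this stack forwards **kwargs up the MRO, as the hasattr(super(), '__init__') guard in the next hunk shows. A self-contained sketch of that cooperative-init pattern, with simplified names that are not the repo's actual signatures:

    # Each mixin pops the kwargs it owns, then hands the rest up the MRO, so a
    # new mixin can be inserted anywhere in the base-class list without breaking
    # initialization of the others.
    class LatentCachingMixin:
        def __init__(self, **kwargs):
            self.is_caching_latents = kwargs.pop('cache_latents', False)  # illustrative flag
            if hasattr(super(), '__init__'):
                super().__init__(**kwargs)

    class TextEmbeddingCachingMixin:
        def __init__(self, **kwargs):
            self.is_caching_text_embeddings = kwargs.pop('cache_text_embeddings', False)  # illustrative flag
            if hasattr(super(), '__init__'):
                super().__init__(**kwargs)

    class Dataset:
        def __init__(self, **kwargs):
            pass  # terminal base; swallows whatever is left

    class AiToolkitDataset(LatentCachingMixin, TextEmbeddingCachingMixin, Dataset):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)  # walks the MRO left to right

    ds = AiToolkitDataset(cache_latents=True, cache_text_embeddings=False)
    print(ds.is_caching_latents, ds.is_caching_text_embeddings)  # True False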
@@ -1780,6 +1780,8 @@ class TextEmbeddingCachingMixin:
         if hasattr(super(), '__init__'):
             super().__init__(**kwargs)
         self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
+        if self.is_caching_text_embeddings:
+            raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")

     def cache_text_embeddings(self: 'AiToolkitDataset'):
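cache_text_embeddings() is left unimplemented and the guard above turns the flag into a hard error, so the eventual design is not settled here. Purely as a speculative sketch in the spirit of the existing latent cache, a disk cache keyed on a hash of the prompt might look like this (every name below is illustrative, not from the repo):

    import hashlib
    import os

    import torch

    def embedding_cache_path(cache_dir: str, prompt: str) -> str:
        # Key the cached tensor by a stable hash of the prompt text.
        key = hashlib.sha256(prompt.encode('utf-8')).hexdigest()
        return os.path.join(cache_dir, f'{key}.pt')

    def get_or_compute_embedding(cache_dir: str, prompt: str, encode_fn):
        # encode_fn stands in for a text-encoder forward pass.
        path = embedding_cache_path(cache_dir, prompt)
        if os.path.exists(path):
            return torch.load(path, map_location='cpu')
        emb = encode_fn(prompt).detach().cpu()
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(emb, path)
        return emb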
@@ -3020,6 +3020,8 @@ class StableDiffusion:
             active_modules = ['vae']
         if device_state_preset in ['cache_clip']:
             active_modules = ['clip']
+        if device_state_preset in ['cache_text_encoder']:
+            active_modules = ['text_encoder']
         if device_state_preset in ['unload']:
             active_modules = []
         if device_state_preset in ['generate']:
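The preset pattern is uniform: each device_state_preset names the submodules to keep on the accelerator while everything else is offloaded, so a text-embedding caching pass only needs the text encoder resident. An illustrative reduction of that idea (the actual StableDiffusion class manages more state than this):

    from typing import Dict, Iterable

    import torch.nn as nn

    def apply_device_state(modules: Dict[str, nn.Module],
                           active: Iterable[str],
                           device: str = 'cuda') -> None:
        # Keep the named modules on the accelerator; push the rest to CPU.
        for name, module in modules.items():
            module.to(device if name in active else 'cpu')

    # e.g. during a text-embedding caching pass:
    # apply_device_state({'vae': vae, 'text_encoder': text_encoder, 'unet': unet},
    #                    active=['text_encoder'])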