mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Some work on caching text embeddings
This commit is contained in:
@@ -816,6 +816,7 @@ class DatasetConfig:
|
|||||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||||
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
||||||
self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
|
self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
|
||||||
|
self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
|
||||||
|
|
||||||
self.standardize_images: bool = kwargs.get('standardize_images', False)
|
self.standardize_images: bool = kwargs.get('standardize_images', False)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import albumentations as A
|
|||||||
from toolkit import image_utils
|
from toolkit import image_utils
|
||||||
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
|
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin, TextEmbeddingCachingMixin
|
||||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||||
from toolkit.print import print_acc
|
from toolkit.print import print_acc
|
||||||
from toolkit.accelerator import get_accelerator
|
from toolkit.accelerator import get_accelerator
|
||||||
@@ -378,7 +378,7 @@ class PairedImageDataset(Dataset):
|
|||||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||||
|
|
||||||
|
|
||||||
class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, TextEmbeddingCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1780,6 +1780,8 @@ class TextEmbeddingCachingMixin:
|
|||||||
if hasattr(super(), '__init__'):
|
if hasattr(super(), '__init__'):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
|
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
|
||||||
|
if self.is_caching_text_embeddings:
|
||||||
|
raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
|
||||||
|
|
||||||
def cache_text_embeddings(self: 'AiToolkitDataset'):
|
def cache_text_embeddings(self: 'AiToolkitDataset'):
|
||||||
|
|
||||||
|
|||||||
@@ -3020,6 +3020,8 @@ class StableDiffusion:
|
|||||||
active_modules = ['vae']
|
active_modules = ['vae']
|
||||||
if device_state_preset in ['cache_clip']:
|
if device_state_preset in ['cache_clip']:
|
||||||
active_modules = ['clip']
|
active_modules = ['clip']
|
||||||
|
if device_state_preset in ['cache_text_encoder']:
|
||||||
|
active_modules = ['text_encoder']
|
||||||
if device_state_preset in ['unload']:
|
if device_state_preset in ['unload']:
|
||||||
active_modules = []
|
active_modules = []
|
||||||
if device_state_preset in ['generate']:
|
if device_state_preset in ['generate']:
|
||||||
|
|||||||
Reference in New Issue
Block a user