Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized.

This commit is contained in:
Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions

View File

@@ -168,7 +168,9 @@ class QwenImageModel(BaseModel):
text_encoder = [pipe.text_encoder] text_encoder = [pipe.text_encoder]
tokenizer = [pipe.tokenizer] tokenizer = [pipe.tokenizer]
pipe.transformer = pipe.transformer.to(self.device_torch) # leave it on cpu for now
if not self.low_vram:
pipe.transformer = pipe.transformer.to(self.device_torch)
flush() flush()
# just to make sure everything is on the right device and dtype # just to make sure everything is on the right device and dtype
@@ -210,6 +212,7 @@ class QwenImageModel(BaseModel):
generator: torch.Generator, generator: torch.Generator,
extra: dict, extra: dict,
): ):
self.model.to(self.device_torch, dtype=self.torch_dtype)
control_img = None control_img = None
if gen_config.ctrl_img is not None: if gen_config.ctrl_img is not None:
raise NotImplementedError( raise NotImplementedError(

View File

@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, ConcatDataset
from toolkit import train_tools from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig from toolkit.config_modules import GenerateImageConfig
from toolkit.data_loader import get_dataloader_datasets from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F import torch.nn.functional as F
from toolkit.unloader import unload_text_encoder
def flush(): def flush():
@@ -108,6 +109,33 @@ class SDTrainer(BaseSDTrainProcess):
def before_model_load(self): def before_model_load(self):
pass pass
def cache_sample_prompts(self):
if self.train_config.disable_sampling:
return
if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0:
# cache all the samples
self.sd.sample_prompts_cache = []
sample_folder = os.path.join(self.save_root, 'samples')
output_path = os.path.join(sample_folder, 'test.jpg')
for i in range(len(self.sample_config.prompts)):
sample_item = self.sample_config.samples[i]
prompt = self.sample_config.prompts[i]
# needed so we can autoparse the prompt to handle flags
gen_img_config = GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
negative_prompt=sample_item.neg,
output_path=output_path,
)
positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
self.sd.sample_prompts_cache.append({
'conditional': positive,
'unconditional': negative
})
def before_dataset_load(self): def before_dataset_load(self):
self.assistant_adapter = None self.assistant_adapter = None
@@ -143,6 +171,9 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self): def hook_before_train_loop(self):
super().hook_before_train_loop() super().hook_before_train_loop()
if self.is_caching_text_embeddings:
# make sure model is on cpu for this part so we don't oom.
self.sd.unet.to('cpu')
# cache unconditional embeds (blank prompt) # cache unconditional embeds (blank prompt)
with torch.no_grad(): with torch.no_grad():
@@ -195,15 +226,18 @@ class SDTrainer(BaseSDTrainProcess):
self.negative_prompt_pool = [self.train_config.negative_prompt] self.negative_prompt_pool = [self.train_config.negative_prompt]
# handle unload text encoder # handle unload text encoder
if self.train_config.unload_text_encoder: if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
with torch.no_grad(): with torch.no_grad():
if self.train_config.train_text_encoder: if self.train_config.train_text_encoder:
raise ValueError("Cannot unload text encoder if training text encoder") raise ValueError("Cannot unload text encoder if training text encoder")
# cache embeddings # cache embeddings
print_acc("\n***** UNLOADING TEXT ENCODER *****") print_acc("\n***** UNLOADING TEXT ENCODER *****")
print_acc("This will train only with a blank prompt or trigger word, if set") if self.is_caching_text_embeddings:
print_acc("If this is not what you want, remove the unload_text_encoder flag") print_acc("Embeddings cached to disk. We dont need the text encoder anymore")
else:
print_acc("This will train only with a blank prompt or trigger word, if set")
print_acc("If this is not what you want, remove the unload_text_encoder flag")
print_acc("***********************************") print_acc("***********************************")
print_acc("") print_acc("")
self.sd.text_encoder_to(self.device_torch) self.sd.text_encoder_to(self.device_torch)
@@ -212,9 +246,16 @@ class SDTrainer(BaseSDTrainProcess):
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word) self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
if self.train_config.diff_output_preservation: if self.train_config.diff_output_preservation:
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
self.cache_sample_prompts()
# move back to cpu # unload the text encoder
self.sd.text_encoder_to('cpu') if self.is_caching_text_embeddings:
unload_text_encoder(self.sd)
else:
# todo once every model is tested to work, unload properly. Though, this will all be merged into one thing.
# keep legacy usage for now.
self.sd.text_encoder_to("cpu")
flush() flush()
if self.train_config.diffusion_feature_extractor_path is not None: if self.train_config.diffusion_feature_extractor_path is not None:
@@ -923,11 +964,14 @@ class SDTrainer(BaseSDTrainProcess):
prompt = prompt.replace(trigger, class_name) prompt = prompt.replace(trigger, class_name)
prompt_list[idx] = prompt prompt_list[idx] = prompt
embeds_to_use = self.sd.encode_prompt( if batch.prompt_embeds is not None:
prompt_list, embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
long_prompts=self.do_long_prompts).to( else:
self.device_torch, embeds_to_use = self.sd.encode_prompt(
dtype=dtype).detach() prompt_list,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype).detach()
# dont use network on this # dont use network on this
# self.network.multiplier = 0.0 # self.network.multiplier = 0.0
@@ -1294,18 +1338,24 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_prompt'): with self.timer('encode_prompt'):
unconditional_embeds = None unconditional_embeds = None
if self.train_config.unload_text_encoder: if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
with torch.set_grad_enabled(False): with torch.set_grad_enabled(False):
embeds_to_use = self.cached_blank_embeds.clone().detach().to( if batch.prompt_embeds is not None:
self.device_torch, dtype=dtype # use the cached embeds
) conditional_embeds = batch.prompt_embeds.clone().detach().to(
if self.cached_trigger_embeds is not None and not is_reg:
embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
self.device_torch, dtype=dtype self.device_torch, dtype=dtype
) )
conditional_embeds = concat_prompt_embeds( else:
[embeds_to_use] * noisy_latents.shape[0] embeds_to_use = self.cached_blank_embeds.clone().detach().to(
) self.device_torch, dtype=dtype
)
if self.cached_trigger_embeds is not None and not is_reg:
embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
conditional_embeds = concat_prompt_embeds(
[embeds_to_use] * noisy_latents.shape[0]
)
if self.train_config.do_cfg: if self.train_config.do_cfg:
unconditional_embeds = self.cached_blank_embeds.clone().detach().to( unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype self.device_torch, dtype=dtype

View File

@@ -145,7 +145,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
raw_datasets = preprocess_dataset_raw_config(raw_datasets) raw_datasets = preprocess_dataset_raw_config(raw_datasets)
self.datasets = None self.datasets = None
self.datasets_reg = None self.datasets_reg = None
self.dataset_configs: List[DatasetConfig] = []
self.params = [] self.params = []
# add dataset text embedding cache to their config
if self.train_config.cache_text_embeddings:
for raw_dataset in raw_datasets:
raw_dataset['cache_text_embeddings'] = True
if raw_datasets is not None and len(raw_datasets) > 0: if raw_datasets is not None and len(raw_datasets) > 0:
for raw_dataset in raw_datasets: for raw_dataset in raw_datasets:
dataset = DatasetConfig(**raw_dataset) dataset = DatasetConfig(**raw_dataset)
@@ -160,6 +167,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.datasets is None: if self.datasets is None:
self.datasets = [] self.datasets = []
self.datasets.append(dataset) self.datasets.append(dataset)
self.dataset_configs.append(dataset)
self.is_caching_text_embeddings = any(
dataset.cache_text_embeddings for dataset in self.dataset_configs
)
# cannot train trigger word if caching text embeddings
if self.is_caching_text_embeddings and self.trigger_word is not None:
raise ValueError("Cannot train trigger word if caching text embeddings. Please remove the trigger word or disable text embedding caching.")
self.embed_config = None self.embed_config = None
embedding_raw = self.get_conf('embedding', None) embedding_raw = self.get_conf('embedding', None)
@@ -206,7 +222,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_embedding=self.embed_config is not None, train_embedding=self.embed_config is not None,
train_decorator=self.decorator_config is not None, train_decorator=self.decorator_config is not None,
train_refiner=self.train_config.train_refiner, train_refiner=self.train_config.train_refiner,
unload_text_encoder=self.train_config.unload_text_encoder, unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
require_grads=False # we ensure them later require_grads=False # we ensure them later
) )
@@ -220,7 +236,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_embedding=self.embed_config is not None, train_embedding=self.embed_config is not None,
train_decorator=self.decorator_config is not None, train_decorator=self.decorator_config is not None,
train_refiner=self.train_config.train_refiner, train_refiner=self.train_config.train_refiner,
unload_text_encoder=self.train_config.unload_text_encoder, unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
require_grads=True # We check for grads when getting params require_grads=True # We check for grads when getting params
) )
@@ -235,7 +251,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.snr_gos: Union[LearnableSNRGamma, None] = None self.snr_gos: Union[LearnableSNRGamma, None] = None
self.ema: ExponentialMovingAverage = None self.ema: ExponentialMovingAverage = None
validate_configs(self.train_config, self.model_config, self.save_config) validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs)
do_profiler = self.get_conf('torch_profiler', False) do_profiler = self.get_conf('torch_profiler', False)
self.torch_profiler = None if not do_profiler else torch.profiler.profile( self.torch_profiler = None if not do_profiler else torch.profiler.profile(

View File

@@ -482,6 +482,8 @@ class TrainConfig:
# will cache a blank prompt or the trigger word, and unload the text encoder to cpu # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
# will make training faster and use less vram # will make training faster and use less vram
self.unload_text_encoder = kwargs.get('unload_text_encoder', False) self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
# will toggle all datasets to cache text embeddings
self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
# for swapping which parameters are trained during training # for swapping which parameters are trained during training
self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False) self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
# 0.1 is 10% of the parameters active at a time lower is less vram, higher is more # 0.1 is 10% of the parameters active at a time lower is less vram, higher is more
@@ -1189,6 +1191,7 @@ def validate_configs(
train_config: TrainConfig, train_config: TrainConfig,
model_config: ModelConfig, model_config: ModelConfig,
save_config: SaveConfig, save_config: SaveConfig,
dataset_configs: List[DatasetConfig]
): ):
if model_config.is_flux: if model_config.is_flux:
if save_config.save_format != 'diffusers': if save_config.save_format != 'diffusers':
@@ -1200,3 +1203,18 @@ def validate_configs(
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss: if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. " raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.") "Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
# see if any datasets are caching text embeddings
is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
if is_caching_text_embeddings:
# check if they are doing differential output preservation
if train_config.diff_output_preservation:
raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")
# make sure they are all cached
for dataset in dataset_configs:
if not dataset.cache_text_embeddings:
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")

View File

@@ -558,6 +558,8 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
self.cache_latents_all_latents() self.cache_latents_all_latents()
if self.is_caching_clip_vision_to_disk: if self.is_caching_clip_vision_to_disk:
self.cache_clip_vision_to_disk() self.cache_clip_vision_to_disk()
if self.is_caching_text_embeddings:
self.cache_text_embeddings()
if self.is_generating_controls: if self.is_generating_controls:
# always do this last # always do this last
self.setup_controls() self.setup_controls()

View File

@@ -13,8 +13,8 @@ from toolkit import image_utils
from toolkit.basic import get_quick_signature_string from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \ ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
if TYPE_CHECKING: if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig from toolkit.config_modules import DatasetConfig
@@ -32,6 +32,7 @@ def print_once(msg):
class FileItemDTO( class FileItemDTO(
LatentCachingFileItemDTOMixin, LatentCachingFileItemDTOMixin,
TextEmbeddingFileItemDTOMixin,
CaptionProcessingDTOMixin, CaptionProcessingDTOMixin,
ImageProcessingDTOMixin, ImageProcessingDTOMixin,
ControlFileItemDTOMixin, ControlFileItemDTOMixin,
@@ -124,6 +125,7 @@ class FileItemDTO(
def cleanup(self): def cleanup(self):
self.tensor = None self.tensor = None
self.cleanup_latent() self.cleanup_latent()
self.cleanup_text_embedding()
self.cleanup_control() self.cleanup_control()
self.cleanup_inpaint() self.cleanup_inpaint()
self.cleanup_clip_image() self.cleanup_clip_image()
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
try: try:
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None) self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached is_latents_cached = self.file_items[0].is_latent_cached
is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
self.tensor: Union[torch.Tensor, None] = None self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
if is_latents_cached: if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None
self.prompt_embeds: Union[PromptEmbeds, None] = None
# if self.file_items[0].control_tensor is not None: # if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them # if any have a control tensor, we concatenate them
if any([x.control_tensor is not None for x in self.file_items]): if any([x.control_tensor is not None for x in self.file_items]):
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional) self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
else: else:
raise Exception("clip_image_embeds_unconditional is None for some file items") raise Exception("clip_image_embeds_unconditional is None for some file items")
if any([x.prompt_embeds is not None for x in self.file_items]):
# find one to use as a base
base_prompt_embeds = None
for x in self.file_items:
if x.prompt_embeds is not None:
base_prompt_embeds = x.prompt_embeds
break
prompt_embeds_list = []
for x in self.file_items:
if x.prompt_embeds is None:
prompt_embeds_list.append(base_prompt_embeds)
else:
prompt_embeds_list.append(x.prompt_embeds)
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
except Exception as e: except Exception as e:
print(e) print(e)

View File

@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
import albumentations as A import albumentations as A
from toolkit.print import print_acc from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype from toolkit.train_tools import get_torch_dtype
@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
self.extra_values: List[float] = dataset_config.extra_values self.extra_values: List[float] = dataset_config.extra_values
# todo allow for loading from sd-scripts style dict # todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
if self.raw_caption is not None: if self.raw_caption is not None:
# we already loaded it # we already loaded it
pass pass
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
if self.dataset_config.num_frames > 1: if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents) self.load_and_process_video(transform, only_load_latents)
return return
# handle get_prompt_embedding
if self.is_text_embedding_cached:
self.load_prompt_embedding()
# if we are caching latents, just do that # if we are caching latents, just do that
if self.is_latent_cached: if self.is_latent_cached:
self.get_latent() self.get_latent()
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
self.sd.restore_device_state() self.sd.restore_device_state()
class TextEmbeddingFileItemDTOMixin:
def __init__(self, *args, **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.prompt_embeds: Union[PromptEmbeds, None] = None
self._text_embedding_path: Union[str, None] = None
self.is_text_embedding_cached = False
self.text_embedding_load_device = 'cpu'
self.text_embedding_space_version = 'sd1'
self.text_embedding_version = 1
def get_text_embedding_info_dict(self: 'FileItemDTO'):
# make sure the caption is loaded here
# TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
if self.caption is None:
self.load_caption()
# throw error is [trigger] in caption as we cannot inject it while caching
if '[trigger]' in self.caption:
raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
item = OrderedDict([
("caption", self.caption),
("text_embedding_space_version", self.text_embedding_space_version),
("text_embedding_version", self.text_embedding_version),
])
return item
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
if self._text_embedding_path is not None and not recalculate:
return self._text_embedding_path
else:
# we store text embeddings in a folder in same path as image called _text_embedding_cache
img_dir = os.path.dirname(self.path)
te_dir = os.path.join(img_dir, '_t_e_cache')
hash_dict = self.get_text_embedding_info_dict()
filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
# get base64 hash of md5 checksum of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._text_embedding_path
def cleanup_text_embedding(self):
if self.prompt_embeds is not None:
# we are caching on disk, don't save in memory
self.prompt_embeds = None
def load_prompt_embedding(self, device=None):
if not self.is_text_embedding_cached:
return
if self.prompt_embeds is None:
# load it from disk
self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())
class TextEmbeddingCachingMixin: class TextEmbeddingCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs): def __init__(self: 'AiToolkitDataset', **kwargs):
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
if hasattr(super(), '__init__'): if hasattr(super(), '__init__'):
super().__init__(**kwargs) super().__init__(**kwargs)
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
if self.is_caching_text_embeddings:
raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")
def cache_text_embeddings(self: 'AiToolkitDataset'): def cache_text_embeddings(self: 'AiToolkitDataset'):
with accelerator.main_process_first(): with accelerator.main_process_first():
print_acc(f"Caching text_embeddings for {self.dataset_path}") print_acc(f"Caching text_embeddings for {self.dataset_path}")
# cache all latents to disk
to_disk = self.is_caching_latents_to_disk
to_memory = self.is_caching_latents_to_memory
print_acc(" - Saving text embeddings to disk") print_acc(" - Saving text embeddings to disk")
# move sd items to cpu except for vae
self.sd.set_device_state_preset('cache_latents') did_move = False
# use tqdm to show progress # use tqdm to show progress
i = 0 i = 0
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
# set latent space version file_item.text_embedding_space_version = self.sd.model_config.arch
if self.sd.model_config.latent_space_version is not None:
file_item.latent_space_version = self.sd.model_config.latent_space_version
elif self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
elif self.sd.is_auraflow:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_flux:
file_item.latent_space_version = 'flux1'
elif self.sd.model_config.is_pixart_sigma:
file_item.latent_space_version = 'sdxl'
else:
file_item.latent_space_version = self.sd.model_config.arch
file_item.is_caching_to_disk = to_disk
file_item.is_caching_to_memory = to_memory
file_item.latent_load_device = self.sd.device file_item.latent_load_device = self.sd.device
latent_path = file_item.get_latent_path(recalculate=True) text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
# check if it is saved to disk already # only process if not saved to disk
if os.path.exists(latent_path): if not os.path.exists(text_embedding_path):
if to_memory: # load if not loaded
# load it into memory if not did_move:
state_dict = load_file(latent_path, device='cpu') self.sd.set_device_state_preset('cache_text_encoder')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) did_move = True
else: prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
# not saved to disk, calculate # save it
# load the image first prompt_embeds.save(text_embedding_path)
file_item.load_and_process_image(self.transform, only_load_latents=True) del prompt_embeds
dtype = self.sd.torch_dtype file_item.is_text_embedding_cached = True
device = self.sd.device_torch
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
save_file(state_dict, latent_path, metadata=meta)
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1 i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state # restore device state
self.sd.restore_device_state() # if did_move:
# self.sd.restore_device_state()
class CLIPCachingMixin: class CLIPCachingMixin:

View File

@@ -168,6 +168,8 @@ class BaseModel:
self._after_sample_img_hooks = [] self._after_sample_img_hooks = []
self._status_update_hooks = [] self._status_update_hooks = []
self.is_transformer = False self.is_transformer = False
self.sample_prompts_cache = None
# properties for old arch for backwards compatibility # properties for old arch for backwards compatibility
@property @property
@@ -484,19 +486,23 @@ class BaseModel:
quad_count=4 quad_count=4
) )
# encode the prompt ourselves so we can do fun stuff with embeddings if self.sample_prompts_cache is not None:
if isinstance(self.adapter, CustomAdapter): conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
self.adapter.is_unconditional_run = False unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
conditional_embeds = self.encode_prompt( else:
gen_config.prompt, gen_config.prompt_2, force_all=True) # encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(
gen_config.prompt, gen_config.prompt_2, force_all=True)
if isinstance(self.adapter, CustomAdapter): if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt( unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
) )
if isinstance(self.adapter, CustomAdapter): if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False self.adapter.is_unconditional_run = False
# allow any manipulations to take place to embeddings # allow any manipulations to take place to embeddings
gen_config.post_process_embeddings( gen_config.post_process_embeddings(

View File

@@ -92,6 +92,56 @@ class PromptEmbeds:
pe.attention_mask = pe.attention_mask.expand(batch_size, -1) pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
return pe return pe
def save(self, path: str):
"""
Save the prompt embeds to a file.
:param path: The path to save the prompt embeds.
"""
pe = self.clone()
state_dict = {}
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
for i, text_embed in enumerate(pe.text_embeds):
state_dict[f"text_embed_{i}"] = text_embed.cpu()
else:
state_dict["text_embed"] = pe.text_embeds.cpu()
if pe.pooled_embeds is not None:
state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
if pe.attention_mask is not None:
state_dict["attention_mask"] = pe.attention_mask.cpu()
os.makedirs(os.path.dirname(path), exist_ok=True)
save_file(state_dict, path)
@classmethod
def load(cls, path: str) -> 'PromptEmbeds':
"""
Load the prompt embeds from a file.
:param path: The path to load the prompt embeds from.
:return: An instance of PromptEmbeds.
"""
state_dict = load_file(path, device='cpu')
text_embeds = []
pooled_embeds = None
attention_mask = None
for key in sorted(state_dict.keys()):
if key.startswith("text_embed_"):
text_embeds.append(state_dict[key])
elif key == "text_embed":
text_embeds.append(state_dict[key])
elif key == "pooled_embed":
pooled_embeds = state_dict[key]
elif key == "attention_mask":
attention_mask = state_dict[key]
pe = cls(None)
pe.text_embeds = text_embeds
if len(text_embeds) == 1:
pe.text_embeds = text_embeds[0]
if pooled_embeds is not None:
pe.pooled_embeds = pooled_embeds
if attention_mask is not None:
pe.attention_mask = attention_mask
return pe
class EncodedPromptPair: class EncodedPromptPair:
def __init__( def __init__(

View File

@@ -209,6 +209,8 @@ class StableDiffusion:
# todo update this based on the model # todo update this based on the model
self.is_transformer = False self.is_transformer = False
self.sample_prompts_cache = None
# properties for old arch for backwards compatibility # properties for old arch for backwards compatibility
@property @property
def is_xl(self): def is_xl(self):
@@ -1426,18 +1428,22 @@ class StableDiffusion:
quad_count=4 quad_count=4
) )
# encode the prompt ourselves so we can do fun stuff with embeddings if self.sample_prompts_cache is not None:
if isinstance(self.adapter, CustomAdapter): conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
self.adapter.is_unconditional_run = False unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) else:
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
if isinstance(self.adapter, CustomAdapter): if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt( unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
) )
if isinstance(self.adapter, CustomAdapter): if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False self.adapter.is_unconditional_run = False
# allow any manipulations to take place to embeddings # allow any manipulations to take place to embeddings
gen_config.post_process_embeddings( gen_config.post_process_embeddings(

63
toolkit/unloader.py Normal file
View File

@@ -0,0 +1,63 @@
import torch
from toolkit.basic import flush
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
class FakeTextEncoder(torch.nn.Module):
    """
    Stand-in module used after the real text encoder has been unloaded.

    It mimics the minimal interface pipelines touch (``device``, ``dtype``,
    ``to``) so device/dtype bookkeeping keeps working, but raises if anyone
    actually tries to run it for inference.
    """

    def __init__(self, device, dtype):
        super().__init__()
        # register a dummy parameter to avoid errors in some cases
        # (e.g. code that inspects .parameters())
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
        self._device = device
        self._dtype = dtype

    def forward(self, *args, **kwargs):
        # fail loudly instead of silently producing garbage embeddings
        raise NotImplementedError(
            "This is a fake text encoder and should not be used for inference."
        )

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    def to(self, *args, **kwargs):
        # no-op: there is nothing meaningful to move
        return self
def unload_text_encoder(model: "BaseModel"):
    """
    Replace the model's text encoder(s) with :class:`FakeTextEncoder` stand-ins.

    This unloads the text encoder in a way that will work with all models and
    will not throw errors: the fakes keep ``device``/``dtype``/``to`` plumbing
    working without holding the real encoder weights.

    :param model: The model whose text encoder(s) should be unloaded.
    """
    if model.text_encoder is not None:
        if isinstance(model.text_encoder, list):
            text_encoder_list = []
            pipe = model.pipeline
            # the pipeline stores text encoders like text_encoder,
            # text_encoder_2, text_encoder_3, etc.
            if hasattr(pipe, "text_encoder"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                pipe.text_encoder = te
            i = 2
            while hasattr(pipe, f"text_encoder_{i}"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                setattr(pipe, f"text_encoder_{i}", te)
                i += 1
            model.text_encoder = text_encoder_list
        else:
            # only has a single text encoder
            # bug fix: FakeTextEncoder requires device/dtype arguments;
            # calling it bare raised TypeError on this branch
            model.text_encoder = FakeTextEncoder(
                device=model.device_torch, dtype=model.torch_dtype
            )
    # free the memory held by the dropped encoder weights
    flush()

View File

@@ -389,22 +389,40 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')} onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
/> />
</FormGroup> </FormGroup>
<NumberInput {jobConfig.config.process[0].train.ema_config?.use_ema && (
label="EMA Decay" <NumberInput
className="pt-2" label="EMA Decay"
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number} className="pt-2"
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')} value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
placeholder="eg. 0.99" onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
min={0} placeholder="eg. 0.99"
/> min={0}
<FormGroup label="Unload Text Encoder" className="pt-2"> />
<div className="grid grid-cols-2 gap-2"> )}
<Checkbox
label="Unload TE" <FormGroup label="Text Encoder Optimizations" className="pt-2">
checked={jobConfig.config.process[0].train.unload_text_encoder || false} <Checkbox
onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')} label="Unload TE"
/> checked={jobConfig.config.process[0].train.unload_text_encoder || false}
</div> docKey={'train.unload_text_encoder'}
onChange={(value) => {
setJobConfig(value, 'config.process[0].train.unload_text_encoder')
if (value) {
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
}
}}
/>
<Checkbox
label="Cache Text Embeddings"
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
docKey={'train.cache_text_embeddings'}
onChange={(value) => {
setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
if (value) {
setJobConfig(false, 'config.process[0].train.unload_text_encoder')
}
}}
/>
</FormGroup> </FormGroup>
</div> </div>
<div> <div>
@@ -416,21 +434,27 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')} onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/> />
</FormGroup> </FormGroup>
<NumberInput {jobConfig.config.process[0].train.diff_output_preservation && (
label="DOP Loss Multiplier" <>
className="pt-2" <NumberInput
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number} label="DOP Loss Multiplier"
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')} className="pt-2"
placeholder="eg. 1.0" value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
min={0} onChange={value =>
/> setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
<TextInput }
label="DOP Preservation Class" placeholder="eg. 1.0"
className="pt-2" min={0}
value={jobConfig.config.process[0].train.diff_output_preservation_class as string} />
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} <TextInput
placeholder="eg. woman" label="DOP Preservation Class"
/> className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
placeholder="eg. woman"
/>
</>
)}
</div> </div>
</div> </div>
</Card> </Card>
@@ -524,16 +548,14 @@ export default function SimpleJob({
checked={dataset.is_reg || false} checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/> />
{ {modelArch?.additionalSections?.includes('datasets.do_i2v') && (
modelArch?.additionalSections?.includes('datasets.do_i2v') && ( <Checkbox
<Checkbox label="Do I2V"
label="Do I2V" checked={dataset.do_i2v || false}
checked={dataset.do_i2v || false} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)} docKey="datasets.do_i2v"
docKey="datasets.do_i2v" />
/> )}
)
}
</FormGroup> </FormGroup>
</div> </div>
<div> <div>

View File

@@ -66,6 +66,7 @@ export const defaultJobConfig: JobConfig = {
weight_decay: 1e-4, weight_decay: 1e-4,
}, },
unload_text_encoder: false, unload_text_encoder: false,
cache_text_embeddings: false,
lr: 0.0001, lr: 0.0001,
ema_config: { ema_config: {
use_ema: false, use_ema: false,

View File

@@ -12,12 +12,12 @@ const docs: { [key: string]: ConfigDoc } = {
</> </>
), ),
}, },
'gpuids': { gpuids: {
title: 'GPU ID', title: 'GPU ID',
description: ( description: (
<> <>
This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently. This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
However, you can start multiple jobs in parallel, each using a different GPU. However, you can start multiple jobs in parallel, each using a different GPU.
</> </>
), ),
}, },
@@ -25,17 +25,19 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Trigger Word', title: 'Trigger Word',
description: ( description: (
<> <>
Optional: This will be the word or token used to trigger your concept or character. Optional: This will be the word or token used to trigger your concept or character.
<br /> <br />
<br /> <br />
When using a trigger word, When using a trigger word, If your captions do not contain the trigger word, it will be added automatically the
If your captions do not contain the trigger word, it will be added automatically the beginning of the caption. If you do not have beginning of the caption. If you do not have captions, the caption will become just the trigger word. If you
captions, the caption will become just the trigger word. If you want to have variable trigger words in your captions to put it in different spots, want to have variable trigger words in your captions to put it in different spots, you can use the{' '}
you can use the <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger word. <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger
<br /> word.
<br /> <br />
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the <br />
<code>{'[trigger]'}</code> placeholder in your test prompts as well. Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger
word manually or use the
<code>{'[trigger]'}</code> placeholder in your test prompts as well.
</> </>
), ),
}, },
@@ -43,8 +45,9 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Name or Path', title: 'Name or Path',
description: ( description: (
<> <>
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all in one safetensors checkpoint here. folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
path to an all in one safetensors checkpoint here.
</> </>
), ),
}, },
@@ -52,8 +55,8 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Control Dataset', title: 'Control Dataset',
description: ( description: (
<> <>
The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs. The control dataset needs to have files that match the filenames of your training dataset. They should be
These images are fed as control/input images during training. matching file pairs. These images are fed as control/input images during training.
</> </>
), ),
}, },
@@ -61,16 +64,19 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Number of Frames', title: 'Number of Frames',
description: ( description: (
<> <>
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame. This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1
If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset. for one frame. If your dataset is only videos, frames will be extracted evenly spaced from the videos in the
<br/> dataset.
<br/> <br />
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video. <br />
So you would want all of your videos trimmed to around 5 seconds for best results. It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames
<br/> will result in a 5 second video. So you would want all of your videos trimmed to around 5 seconds for best
<br/> results.
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81 <br />
evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast. <br />
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long,
will result in 81 evenly spaced frames for each video making the 2 second video appear slow and the 90second
video appear very fast.
</> </>
), ),
}, },
@@ -78,9 +84,30 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Do I2V', title: 'Do I2V',
description: ( description: (
<> <>
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image dataset to be trained as an I2V dataset. This means that the first frame will be extracted from the video and
for the video. If this option is not set, the dataset will be treated as a T2V dataset. used as the start image for the video. If this option is not set, the dataset will be treated as a T2V dataset.
</>
),
},
'train.unload_text_encoder': {
title: 'Unload Text Encoder',
description: (
<>
Unloading the text encoder will cache the trigger word and the sample prompts and unload the text encoder from the
GPU. Captions for the dataset will be ignored.
</>
),
},
'train.cache_text_embeddings': {
title: 'Cache Text Embeddings',
description: (
<>
<small>(experimental)</small>
<br />
Caching text embeddings will process and cache all the text embeddings from the text encoder to the disk. The
text encoder will be unloaded from the GPU. This does not work with things that dynamically change the prompt
such as trigger words, caption dropout, etc.
</> </>
), ),
}, },

View File

@@ -110,6 +110,7 @@ export interface TrainConfig {
ema_config?: EMAConfig; ema_config?: EMAConfig;
dtype: string; dtype: string;
unload_text_encoder: boolean; unload_text_encoder: boolean;
cache_text_embeddings: boolean;
optimizer_params: { optimizer_params: {
weight_decay: number; weight_decay: number;
}; };

View File

@@ -1 +1 @@
VERSION = "0.3.18" VERSION = "0.4.0"