diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index bcd42ed3..b4491962 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -168,7 +168,9 @@ class QwenImageModel(BaseModel): text_encoder = [pipe.text_encoder] tokenizer = [pipe.tokenizer] - pipe.transformer = pipe.transformer.to(self.device_torch) + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) flush() # just to make sure everything is on the right device and dtype @@ -210,6 +212,7 @@ class QwenImageModel(BaseModel): generator: torch.Generator, extra: dict, ): + self.model.to(self.device_torch, dtype=self.torch_dtype) control_img = None if gen_config.ctrl_img is not None: raise NotImplementedError( diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ca808823..22fd465a 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, ConcatDataset from toolkit import train_tools from toolkit.basic import value_map, adain, get_mean_std from toolkit.clip_vision_adapter import ClipVisionAdapter -from toolkit.config_modules import GuidanceConfig +from toolkit.config_modules import GenerateImageConfig from toolkit.data_loader import get_dataloader_datasets from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType @@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.util.wavelet_loss import wavelet_loss import torch.nn.functional as F +from toolkit.unloader import unload_text_encoder def flush(): @@ -108,6 +109,33 @@ class SDTrainer(BaseSDTrainProcess): def before_model_load(self): pass + + def cache_sample_prompts(self): + if self.train_config.disable_sampling: + return + if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0: + # cache all the samples + self.sd.sample_prompts_cache = [] + sample_folder = os.path.join(self.save_root, 'samples') + output_path = os.path.join(sample_folder, 'test.jpg') + for i in range(len(self.sample_config.prompts)): + sample_item = self.sample_config.samples[i] + prompt = self.sample_config.prompts[i] + + # needed so we can autoparse the prompt to handle flags + gen_img_config = GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + negative_prompt=sample_item.neg, + output_path=output_path, + ) + positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu') + negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu') + + self.sd.sample_prompts_cache.append({ + 'conditional': positive, + 'unconditional': negative + }) + def before_dataset_load(self): self.assistant_adapter = None @@ -143,6 +171,9 @@ class SDTrainer(BaseSDTrainProcess): def hook_before_train_loop(self): super().hook_before_train_loop() + if self.is_caching_text_embeddings: + # make sure model is on cpu for this part so we don't oom. 
+ self.sd.unet.to('cpu') # cache unconditional embeds (blank prompt) with torch.no_grad(): @@ -195,15 +226,18 @@ class SDTrainer(BaseSDTrainProcess): self.negative_prompt_pool = [self.train_config.negative_prompt] # handle unload text encoder - if self.train_config.unload_text_encoder: + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: with torch.no_grad(): if self.train_config.train_text_encoder: raise ValueError("Cannot unload text encoder if training text encoder") # cache embeddings print_acc("\n***** UNLOADING TEXT ENCODER *****") - print_acc("This will train only with a blank prompt or trigger word, if set") - print_acc("If this is not what you want, remove the unload_text_encoder flag") + if self.is_caching_text_embeddings: + print_acc("Embeddings cached to disk. We dont need the text encoder anymore") + else: + print_acc("This will train only with a blank prompt or trigger word, if set") + print_acc("If this is not what you want, remove the unload_text_encoder flag") print_acc("***********************************") print_acc("") self.sd.text_encoder_to(self.device_torch) @@ -212,9 +246,16 @@ class SDTrainer(BaseSDTrainProcess): self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word) if self.train_config.diff_output_preservation: self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) + + self.cache_sample_prompts() - # move back to cpu - self.sd.text_encoder_to('cpu') + # unload the text encoder + if self.is_caching_text_embeddings: + unload_text_encoder(self.sd) + else: + # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing. + # keep legacy usage for now. + self.sd.text_encoder_to("cpu") flush() if self.train_config.diffusion_feature_extractor_path is not None: @@ -923,11 +964,14 @@ class SDTrainer(BaseSDTrainProcess): prompt = prompt.replace(trigger, class_name) prompt_list[idx] = prompt - embeds_to_use = self.sd.encode_prompt( - prompt_list, - long_prompts=self.do_long_prompts).to( - self.device_torch, - dtype=dtype).detach() + if batch.prompt_embeds is not None: + embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype) + else: + embeds_to_use = self.sd.encode_prompt( + prompt_list, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype).detach() # dont use network on this # self.network.multiplier = 0.0 @@ -1294,18 +1338,24 @@ class SDTrainer(BaseSDTrainProcess): with self.timer('encode_prompt'): unconditional_embeds = None - if self.train_config.unload_text_encoder: + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: with torch.set_grad_enabled(False): - embeds_to_use = self.cached_blank_embeds.clone().detach().to( - self.device_torch, dtype=dtype - ) - if self.cached_trigger_embeds is not None and not is_reg: - embeds_to_use = self.cached_trigger_embeds.clone().detach().to( + if batch.prompt_embeds is not None: + # use the cached embeds + conditional_embeds = batch.prompt_embeds.clone().detach().to( self.device_torch, dtype=dtype ) - conditional_embeds = concat_prompt_embeds( - [embeds_to_use] * noisy_latents.shape[0] - ) + else: + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + if self.cached_trigger_embeds is not None and not is_reg: + embeds_to_use = self.cached_trigger_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + conditional_embeds = concat_prompt_embeds( + [embeds_to_use] * 
noisy_latents.shape[0] + ) if self.train_config.do_cfg: unconditional_embeds = self.cached_blank_embeds.clone().detach().to( self.device_torch, dtype=dtype diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 5da603ef..156ed766 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -145,7 +145,14 @@ class BaseSDTrainProcess(BaseTrainProcess): raw_datasets = preprocess_dataset_raw_config(raw_datasets) self.datasets = None self.datasets_reg = None + self.dataset_configs: List[DatasetConfig] = [] self.params = [] + + # add dataset text embedding cache to their config + if self.train_config.cache_text_embeddings: + for raw_dataset in raw_datasets: + raw_dataset['cache_text_embeddings'] = True + if raw_datasets is not None and len(raw_datasets) > 0: for raw_dataset in raw_datasets: dataset = DatasetConfig(**raw_dataset) @@ -160,6 +167,15 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.datasets is None: self.datasets = [] self.datasets.append(dataset) + self.dataset_configs.append(dataset) + + self.is_caching_text_embeddings = any( + dataset.cache_text_embeddings for dataset in self.dataset_configs + ) + + # cannot train trigger word if caching text embeddings + if self.is_caching_text_embeddings and self.trigger_word is not None: + raise ValueError("Cannot train trigger word if caching text embeddings. Please remove the trigger word or disable text embedding caching.") self.embed_config = None embedding_raw = self.get_conf('embedding', None) @@ -206,7 +222,7 @@ class BaseSDTrainProcess(BaseTrainProcess): train_embedding=self.embed_config is not None, train_decorator=self.decorator_config is not None, train_refiner=self.train_config.train_refiner, - unload_text_encoder=self.train_config.unload_text_encoder, + unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings, require_grads=False # we ensure them later ) @@ -220,7 +236,7 @@ class BaseSDTrainProcess(BaseTrainProcess): train_embedding=self.embed_config is not None, train_decorator=self.decorator_config is not None, train_refiner=self.train_config.train_refiner, - unload_text_encoder=self.train_config.unload_text_encoder, + unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings, require_grads=True # We check for grads when getting params ) @@ -235,7 +251,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.snr_gos: Union[LearnableSNRGamma, None] = None self.ema: ExponentialMovingAverage = None - validate_configs(self.train_config, self.model_config, self.save_config) + validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs) do_profiler = self.get_conf('torch_profiler', False) self.torch_profiler = None if not do_profiler else torch.profiler.profile( diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 98454b74..e5a5633d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -482,6 +482,8 @@ class TrainConfig: # will cache a blank prompt or the trigger word, and unload the text encoder to cpu # will make training faster and use less vram self.unload_text_encoder = kwargs.get('unload_text_encoder', False) + # will toggle all datasets to cache text embeddings + self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False) # for swapping which parameters are trained during training self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False) # 0.1 is 10% of the parameters active at a time 
lower is less vram, higher is more @@ -1189,6 +1191,7 @@ def validate_configs( train_config: TrainConfig, model_config: ModelConfig, save_config: SaveConfig, + dataset_configs: List[DatasetConfig] ): if model_config.is_flux: if save_config.save_format != 'diffusers': @@ -1200,3 +1203,18 @@ def validate_configs( if train_config.bypass_guidance_embedding and train_config.do_guidance_loss: raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. " "Please set bypass_guidance_embedding to False or do_guidance_loss to False.") + + # see if any datasets are caching text embeddings + is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs) + if is_caching_text_embeddings: + + # check if they are doing differential output preservation + if train_config.diff_output_preservation: + raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.") + + # make sure they are all cached + for dataset in dataset_configs: + if not dataset.cache_text_embeddings: + raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.") + + diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 257cd462..43f03647 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -558,6 +558,8 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin self.cache_latents_all_latents() if self.is_caching_clip_vision_to_disk: self.cache_clip_vision_to_disk() + if self.is_caching_text_embeddings: + self.cache_text_embeddings() if self.is_generating_controls: # always do this last self.setup_controls() diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index bcc6c918..c37bec1c 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -13,8 +13,8 @@ from toolkit import image_utils from toolkit.basic import get_quick_signature_string from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \ - UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin - + UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds if TYPE_CHECKING: from toolkit.config_modules import DatasetConfig @@ -32,6 +32,7 @@ def print_once(msg): class FileItemDTO( LatentCachingFileItemDTOMixin, + TextEmbeddingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin, ControlFileItemDTOMixin, @@ -124,6 +125,7 @@ class FileItemDTO( def cleanup(self): self.tensor = None self.cleanup_latent() + self.cleanup_text_embedding() self.cleanup_control() self.cleanup_inpaint() self.cleanup_clip_image() @@ -136,6 +138,7 @@ class DataLoaderBatchDTO: try: self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None) is_latents_cached = self.file_items[0].is_latent_cached + is_text_embedding_cached = self.file_items[0].is_text_embedding_cached self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None @@ -156,6 +159,7 @@ class DataLoaderBatchDTO: if is_latents_cached: self.latents = 
torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) self.control_tensor: Union[torch.Tensor, None] = None + self.prompt_embeds: Union[PromptEmbeds, None] = None # if self.file_items[0].control_tensor is not None: # if any have a control tensor, we concatenate them if any([x.control_tensor is not None for x in self.file_items]): @@ -268,6 +272,22 @@ class DataLoaderBatchDTO: self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional) else: raise Exception("clip_image_embeds_unconditional is None for some file items") + + if any([x.prompt_embeds is not None for x in self.file_items]): + # find one to use as a base + base_prompt_embeds = None + for x in self.file_items: + if x.prompt_embeds is not None: + base_prompt_embeds = x.prompt_embeds + break + prompt_embeds_list = [] + for x in self.file_items: + if x.prompt_embeds is None: + prompt_embeds_list.append(base_prompt_embeds) + else: + prompt_embeds_list.append(x.prompt_embeds) + self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list) + except Exception as e: print(e) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index d804ed98..e39be2b9 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose import albumentations as A from toolkit.print import print_acc from toolkit.accelerator import get_accelerator +from toolkit.prompt_utils import PromptEmbeds from toolkit.train_tools import get_torch_dtype @@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin: self.extra_values: List[float] = dataset_config.extra_values # todo allow for loading from sd-scripts style dict - def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): + def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None): if self.raw_caption is not None: # we already loaded it pass @@ -635,6 +636,9 @@ class ImageProcessingDTOMixin: if self.dataset_config.num_frames > 1: self.load_and_process_video(transform, only_load_latents) return + # handle get_prompt_embedding + if self.is_text_embedding_cached: + self.load_prompt_embedding() # if we are caching latents, just do that if self.is_latent_cached: self.get_latent() @@ -1773,6 +1777,61 @@ class LatentCachingMixin: self.sd.restore_device_state() +class TextEmbeddingFileItemDTOMixin: + def __init__(self, *args, **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.prompt_embeds: Union[PromptEmbeds, None] = None + self._text_embedding_path: Union[str, None] = None + self.is_text_embedding_cached = False + self.text_embedding_load_device = 'cpu' + self.text_embedding_space_version = 'sd1' + self.text_embedding_version = 1 + + def get_text_embedding_info_dict(self: 'FileItemDTO'): + # make sure the caption is loaded here + # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible. + if self.caption is None: + self.load_caption() + # throw error is [trigger] in caption as we cannot inject it while caching + if '[trigger]' in self.caption: + raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. 
Please remove it from the caption.") + item = OrderedDict([ + ("caption", self.caption), + ("text_embedding_space_version", self.text_embedding_space_version), + ("text_embedding_version", self.text_embedding_version), + ]) + return item + + def get_text_embedding_path(self: 'FileItemDTO', recalculate=False): + if self._text_embedding_path is not None and not recalculate: + return self._text_embedding_path + else: + # we store text embeddings in a folder in same path as image called _text_embedding_cache + img_dir = os.path.dirname(self.path) + te_dir = os.path.join(img_dir, '_t_e_cache') + hash_dict = self.get_text_embedding_info_dict() + filename_no_ext = os.path.splitext(os.path.basename(self.path))[0] + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors') + + return self._text_embedding_path + + def cleanup_text_embedding(self): + if self.prompt_embeds is not None: + # we are caching on disk, don't save in memory + self.prompt_embeds = None + + def load_prompt_embedding(self, device=None): + if not self.is_text_embedding_cached: + return + if self.prompt_embeds is None: + # load it from disk + self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path()) class TextEmbeddingCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): @@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin: if hasattr(super(), '__init__'): super().__init__(**kwargs) self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings - if self.is_caching_text_embeddings: - raise Exception("Error: caching text embeddings is a WIP and is not supported yet. 
Please set cache_text_embeddings to False in the dataset config") def cache_text_embeddings(self: 'AiToolkitDataset'): - with accelerator.main_process_first(): print_acc(f"Caching text_embeddings for {self.dataset_path}") - # cache all latents to disk - to_disk = self.is_caching_latents_to_disk - to_memory = self.is_caching_latents_to_memory print_acc(" - Saving text embeddings to disk") - # move sd items to cpu except for vae - self.sd.set_device_state_preset('cache_latents') + + did_move = False # use tqdm to show progress i = 0 - for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): - # set latent space version - if self.sd.model_config.latent_space_version is not None: - file_item.latent_space_version = self.sd.model_config.latent_space_version - elif self.sd.is_xl: - file_item.latent_space_version = 'sdxl' - elif self.sd.is_v3: - file_item.latent_space_version = 'sd3' - elif self.sd.is_auraflow: - file_item.latent_space_version = 'sdxl' - elif self.sd.is_flux: - file_item.latent_space_version = 'flux1' - elif self.sd.model_config.is_pixart_sigma: - file_item.latent_space_version = 'sdxl' - else: - file_item.latent_space_version = self.sd.model_config.arch - file_item.is_caching_to_disk = to_disk - file_item.is_caching_to_memory = to_memory + for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'): + file_item.text_embedding_space_version = self.sd.model_config.arch file_item.latent_load_device = self.sd.device - latent_path = file_item.get_latent_path(recalculate=True) - # check if it is saved to disk already - if os.path.exists(latent_path): - if to_memory: - # load it into memory - state_dict = load_file(latent_path, device='cpu') - file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) - else: - # not saved to disk, calculate - # load the image first - file_item.load_and_process_image(self.transform, only_load_latents=True) - dtype = self.sd.torch_dtype - device = self.sd.device_torch - # add batch dimension - try: - imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) - latent = self.sd.encode_images(imgs).squeeze(0) - except Exception as e: - print_acc(f"Error processing image: {file_item.path}") - print_acc(f"Error: {str(e)}") - raise e - # save_latent - if to_disk: - state_dict = OrderedDict([ - ('latent', latent.clone().detach().cpu()), - ]) - # metadata - meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) - os.makedirs(os.path.dirname(latent_path), exist_ok=True) - save_file(state_dict, latent_path, metadata=meta) - - if to_memory: - # keep it in memory - file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) - - del imgs - del latent - del file_item.tensor - - # flush(garbage_collect=False) - file_item.is_latent_cached = True + text_embedding_path = file_item.get_text_embedding_path(recalculate=True) + # only process if not saved to disk + if not os.path.exists(text_embedding_path): + # load if not loaded + if not did_move: + self.sd.set_device_state_preset('cache_text_encoder') + did_move = True + prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) + # save it + prompt_embeds.save(text_embedding_path) + del prompt_embeds + file_item.is_text_embedding_cached = True i += 1 - # flush every 100 - # if i % 100 == 0: - # flush() - # restore device state - self.sd.restore_device_state() + # if did_move: + # self.sd.restore_device_state() class CLIPCachingMixin: diff --git a/toolkit/models/base_model.py 
b/toolkit/models/base_model.py index 150c431d..fc0cb764 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -168,6 +168,8 @@ class BaseModel: self._after_sample_img_hooks = [] self._status_update_hooks = [] self.is_transformer = False + + self.sample_prompts_cache = None # properties for old arch for backwards compatibility @property @@ -484,19 +486,23 @@ class BaseModel: quad_count=4 ) - # encode the prompt ourselves so we can do fun stuff with embeddings - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - conditional_embeds = self.encode_prompt( - gen_config.prompt, gen_config.prompt_2, force_all=True) + if self.sample_prompts_cache is not None: + conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) + unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) + else: + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt( + gen_config.prompt, gen_config.prompt_2, force_all=True) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = True - unconditional_embeds = self.encode_prompt( - gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True - ) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False # allow any manipulations to take place to embeddings gen_config.post_process_embeddings( diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index ff5a68f3..b0558b8c 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -92,6 +92,56 @@ class PromptEmbeds: pe.attention_mask = pe.attention_mask.expand(batch_size, -1) return pe + def save(self, path: str): + """ + Save the prompt embeds to a file. + :param path: The path to save the prompt embeds. + """ + pe = self.clone() + state_dict = {} + if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): + for i, text_embed in enumerate(pe.text_embeds): + state_dict[f"text_embed_{i}"] = text_embed.cpu() + else: + state_dict["text_embed"] = pe.text_embeds.cpu() + + if pe.pooled_embeds is not None: + state_dict["pooled_embed"] = pe.pooled_embeds.cpu() + if pe.attention_mask is not None: + state_dict["attention_mask"] = pe.attention_mask.cpu() + os.makedirs(os.path.dirname(path), exist_ok=True) + save_file(state_dict, path) + + @classmethod + def load(cls, path: str) -> 'PromptEmbeds': + """ + Load the prompt embeds from a file. + :param path: The path to load the prompt embeds from. + :return: An instance of PromptEmbeds. 
+ """ + state_dict = load_file(path, device='cpu') + text_embeds = [] + pooled_embeds = None + attention_mask = None + for key in sorted(state_dict.keys()): + if key.startswith("text_embed_"): + text_embeds.append(state_dict[key]) + elif key == "text_embed": + text_embeds.append(state_dict[key]) + elif key == "pooled_embed": + pooled_embeds = state_dict[key] + elif key == "attention_mask": + attention_mask = state_dict[key] + pe = cls(None) + pe.text_embeds = text_embeds + if len(text_embeds) == 1: + pe.text_embeds = text_embeds[0] + if pooled_embeds is not None: + pe.pooled_embeds = pooled_embeds + if attention_mask is not None: + pe.attention_mask = attention_mask + return pe + class EncodedPromptPair: def __init__( diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index bbed161e..183cbb8d 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -209,6 +209,8 @@ class StableDiffusion: # todo update this based on the model self.is_transformer = False + self.sample_prompts_cache = None + # properties for old arch for backwards compatibility @property def is_xl(self): @@ -1426,18 +1428,22 @@ class StableDiffusion: quad_count=4 ) - # encode the prompt ourselves so we can do fun stuff with embeddings - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) + if self.sample_prompts_cache is not None: + conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) + unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) + else: + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = True - unconditional_embeds = self.encode_prompt( - gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True - ) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False # allow any manipulations to take place to embeddings gen_config.post_process_embeddings( diff --git a/toolkit/unloader.py b/toolkit/unloader.py new file mode 100644 index 00000000..6c45926f --- /dev/null +++ b/toolkit/unloader.py @@ -0,0 +1,63 @@ +import torch +from toolkit.basic import flush +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from toolkit.models.base_model import BaseModel + + +class FakeTextEncoder(torch.nn.Module): + def __init__(self, device, dtype): + super().__init__() + # register a dummy parameter to avoid errors in some cases + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + self._device = device + self._dtype = dtype + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This is a fake text encoder and should not be used for inference." 
+        )
+        return None
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def to(self, *args, **kwargs):
+        return self
+
+
+def unload_text_encoder(model: "BaseModel"):
+    # unload the text encoder in a way that will work with all models and will not throw errors.
+    # we need to make it appear as a text encoder module without actually having one, so that
+    # .to() calls and the like will still work.
+
+    if model.text_encoder is not None:
+        if isinstance(model.text_encoder, list):
+            text_encoder_list = []
+            pipe = model.pipeline
+
+            # the pipeline stores text encoders as text_encoder, text_encoder_2, text_encoder_3, etc.
+            if hasattr(pipe, "text_encoder"):
+                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
+                text_encoder_list.append(te)
+                pipe.text_encoder = te
+
+            i = 2
+            while hasattr(pipe, f"text_encoder_{i}"):
+                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
+                text_encoder_list.append(te)
+                setattr(pipe, f"text_encoder_{i}", te)
+                i += 1
+            model.text_encoder = text_encoder_list
+        else:
+            # only has a single text encoder
+            model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
+
+    flush()
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index 016070bc..7ba08e68 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -389,22 +389,40 @@ export default function SimpleJob({
              onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
            />
-             setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
-            placeholder="eg. 0.99"
-            min={0}
-          />
-          
- setJobConfig(value, 'config.process[0].train.unload_text_encoder')} - /> -
+ {jobConfig.config.process[0].train.ema_config?.use_ema && ( + setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')} + placeholder="eg. 0.99" + min={0} + /> + )} + + + { + setJobConfig(value, 'config.process[0].train.unload_text_encoder') + if (value) { + setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); + } + }} + /> + { + setJobConfig(value, 'config.process[0].train.cache_text_embeddings') + if (value) { + setJobConfig(false, 'config.process[0].train.unload_text_encoder') + } + }} + />
@@ -416,21 +434,27 @@ export default function SimpleJob({ onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')} /> - setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')} - placeholder="eg. 1.0" - min={0} - /> - setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} - placeholder="eg. woman" - /> + {jobConfig.config.process[0].train.diff_output_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} + placeholder="eg. woman" + /> + + )}
@@ -524,16 +548,14 @@ export default function SimpleJob({ checked={dataset.is_reg || false} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} /> - { - modelArch?.additionalSections?.includes('datasets.do_i2v') && ( - setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)} - docKey="datasets.do_i2v" - /> - ) - } + {modelArch?.additionalSections?.includes('datasets.do_i2v') && ( + setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)} + docKey="datasets.do_i2v" + /> + )}
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index ba24e7e4..91844515 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -66,6 +66,7 @@ export const defaultJobConfig: JobConfig = { weight_decay: 1e-4, }, unload_text_encoder: false, + cache_text_embeddings: false, lr: 0.0001, ema_config: { use_ema: false, diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index 88d6c479..d05bda25 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -12,12 +12,12 @@ const docs: { [key: string]: ConfigDoc } = { ), }, - 'gpuids': { + gpuids: { title: 'GPU ID', description: ( <> - This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently. - However, you can start multiple jobs in parallel, each using a different GPU. + This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently. + However, you can start multiple jobs in parallel, each using a different GPU. ), }, @@ -25,17 +25,19 @@ const docs: { [key: string]: ConfigDoc } = { title: 'Trigger Word', description: ( <> - Optional: This will be the word or token used to trigger your concept or character. -
-
- When using a trigger word, - If your captions do not contain the trigger word, it will be added automatically the beginning of the caption. If you do not have - captions, the caption will become just the trigger word. If you want to have variable trigger words in your captions to put it in different spots, - you can use the {'[trigger]'} placeholder in your captions. This will be automatically replaced with your trigger word. -
-
- Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the - {'[trigger]'} placeholder in your test prompts as well. + Optional: This will be the word or token used to trigger your concept or character. +
+
+            When using a trigger word, if your captions do not contain the trigger word, it will be added automatically to the
+            beginning of the caption. If you do not have captions, the caption will become just the trigger word. If you
+            want to have variable trigger words in your captions to put it in different spots, you can use the{' '}
+            {'[trigger]'} placeholder in your captions. This will be automatically replaced with your trigger
+            word.
+
+
+ Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger + word manually or use the + {'[trigger]'} placeholder in your test prompts as well. ), }, @@ -43,8 +45,9 @@ const docs: { [key: string]: ConfigDoc } = { title: 'Name or Path', description: ( <> - The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in - diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all in one safetensors checkpoint here. + The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The + folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the + path to an all in one safetensors checkpoint here. ), }, @@ -52,8 +55,8 @@ const docs: { [key: string]: ConfigDoc } = { title: 'Control Dataset', description: ( <> - The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs. - These images are fed as control/input images during training. + The control dataset needs to have files that match the filenames of your training dataset. They should be + matching file pairs. These images are fed as control/input images during training. ), }, @@ -61,16 +64,19 @@ const docs: { [key: string]: ConfigDoc } = { title: 'Number of Frames', description: ( <> - This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame. - If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset. -
-
- It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video. - So you would want all of your videos trimmed to around 5 seconds for best results. -
-
- Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81 - evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast. + This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 + for one frame. If your dataset is only videos, frames will be extracted evenly spaced from the videos in the + dataset. +
+
+ It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames + will result in a 5 second video. So you would want all of your videos trimmed to around 5 seconds for best + results. +
+
+        Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long,
+        will result in 81 evenly spaced frames for each video, making the 2 second video appear slow and the 90 second
+        video appear very fast.
     ),
   },
   'datasets.do_i2v': {
     title: 'Do I2V',
     description: (
       <>
-        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
-        to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
-        for the video. If this option is not set, the dataset will be treated as a T2V dataset.
+        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this
+        dataset to be trained as an I2V dataset. This means that the first frame will be extracted from the video and
+        used as the start image for the video. If this option is not set, the dataset will be treated as a T2V dataset.
+
+    ),
+  },
+  'train.unload_text_encoder': {
+    title: 'Unload Text Encoder',
+    description: (
+      <>
+        Unloading the text encoder will cache the trigger word and the sample prompts and unload the text encoder from the
+        GPU. Captions for the dataset will be ignored.
+
+    ),
+  },
+  'train.cache_text_embeddings': {
+    title: 'Cache Text Embeddings',
+    description: (
+      <>
+        (experimental)
+
+        Caching text embeddings will process and cache all the text embeddings from the text encoder to disk. The
+        text encoder will be unloaded from the GPU. This does not work with things that dynamically change the prompt,
+        such as trigger words, caption dropout, etc.
+
+    ),
+  },
diff --git a/ui/src/types.ts b/ui/src/types.ts
index 08b22c8a..7834b9c2 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -110,6 +110,7 @@ export interface TrainConfig {
   ema_config?: EMAConfig;
   dtype: string;
   unload_text_encoder: boolean;
+  cache_text_embeddings: boolean;
   optimizer_params: {
     weight_decay: number;
   };
diff --git a/version.py b/version.py
index be340cc3..0e6b1dd4 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.3.18"
\ No newline at end of file
+VERSION = "0.4.0"
\ No newline at end of file
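For reviewers who want to poke at the new cache format outside of a training run, here is a minimal sketch (not part of the diff) that round-trips a PromptEmbeds object through the save()/load() helpers added in toolkit/prompt_utils.py. The path and tensor shapes are made up for illustration, and it assumes PromptEmbeds(None) yields an empty container, which load() itself relies on.

```python
# Illustrative only: exercise the new PromptEmbeds.save()/load() helpers.
import torch

from toolkit.prompt_utils import PromptEmbeds

pe = PromptEmbeds(None)                     # empty container, as load() constructs internally
pe.text_embeds = torch.randn(1, 77, 768)    # stand-in for real text encoder output
pe.pooled_embeds = torch.randn(1, 768)      # optional; not every model produces pooled embeds

# hypothetical path; the dataloader derives real paths from a caption hash
path = "_t_e_cache/example.safetensors"
pe.save(path)                               # writes a safetensors file, creating the folder if needed

restored = PromptEmbeds.load(path)          # tensors come back on the CPU
assert torch.allclose(pe.text_embeds, restored.text_embeds)
assert torch.allclose(pe.pooled_embeds, restored.pooled_embeds)
```

In the actual dataloader path, get_text_embedding_path() hashes the caption together with the embedding-space and version metadata into a _t_e_cache folder next to the images, so a changed caption simply produces a new cache file rather than serving a stale embedding.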