diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 5b615f27..e5fa9e87 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -76,6 +76,9 @@ class SDTrainer(BaseSDTrainProcess):
                 return org_unscale_grads(optimizer, inv_scale, found_inf, True)
             self.scaler._unscale_grads_ = _unscale_grads_replacer
 
+        self.cached_blank_embeds: Optional[PromptEmbeds] = None
+        self.cached_trigger_embeds: Optional[PromptEmbeds] = None
+
     def before_model_load(self):
         pass
 
@@ -155,6 +158,28 @@ class SDTrainer(BaseSDTrainProcess):
             # single prompt
             self.negative_prompt_pool = [self.train_config.negative_prompt]
 
+        # handle unload text encoder
+        if self.train_config.unload_text_encoder:
+            with torch.no_grad():
+                if self.train_config.train_text_encoder:
+                    raise ValueError("Cannot unload text encoder if training text encoder")
+                # cache embeddings
+
+                print("\n***** UNLOADING TEXT ENCODER *****")
+                print("This will train only with a blank prompt or trigger word, if set")
+                print("If this is not what you want, remove the unload_text_encoder flag")
+                print("***********************************")
+                print("")
+                self.sd.text_encoder_to(self.device_torch)
+                self.cached_blank_embeds = self.sd.encode_prompt("")
+                if self.trigger_word is not None:
+                    self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
+
+                # move back to cpu
+                self.sd.text_encoder_to('cpu')
+                flush()
+
+
     def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
         # to process turbo learning, we make one big step from our current timestep to the end
         # we then denoise the prediction on that remaining step and target our loss to our target latents
@@ -799,6 +824,8 @@ class SDTrainer(BaseSDTrainProcess):
                 was_adapter_active = self.adapter.is_active
                 self.adapter.is_active = False
 
+            if self.train_config.unload_text_encoder:
+                raise ValueError("Prior predictions currently do not support unloading text encoder")
             # do a prediction here so we can match its output with network multiplier set to 0.0
             with torch.no_grad():
                 dtype = get_torch_dtype(self.train_config.dtype)
@@ -1183,7 +1210,30 @@ class SDTrainer(BaseSDTrainProcess):
 
         with self.timer('encode_prompt'):
             unconditional_embeds = None
-            if grad_on_text_encoder:
+            if self.train_config.unload_text_encoder:
+                with torch.set_grad_enabled(False):
+                    embeds_to_use = self.cached_blank_embeds.clone().detach().to(
+                        self.device_torch, dtype=dtype
+                    )
+                    if self.cached_trigger_embeds is not None and not is_reg:
+                        embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
+                            self.device_torch, dtype=dtype
+                        )
+                    conditional_embeds = concat_prompt_embeds(
+                        [embeds_to_use] * noisy_latents.shape[0]
+                    )
+                    if self.train_config.do_cfg:
+                        unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
+                            self.device_torch, dtype=dtype
+                        )
+                        unconditional_embeds = concat_prompt_embeds(
+                            [unconditional_embeds] * noisy_latents.shape[0]
+                        )
+
+                if isinstance(self.adapter, CustomAdapter):
+                    self.adapter.is_unconditional_run = False
+
+            elif grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
                     if isinstance(self.adapter, CustomAdapter):
                         self.adapter.is_unconditional_run = False
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 6244b46c..8c480147 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -174,6 +174,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
             train_refiner=self.train_config.train_refiner,
+            unload_text_encoder=self.train_config.unload_text_encoder
         )
 
         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
@@ -1750,6 +1751,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     if self.train_config.free_u:
                         self.sd.pipeline.disable_freeu()
                 self.sample(self.step_num)
+                if self.train_config.unload_text_encoder:
+                    # make sure the text encoder is unloaded
+                    self.sd.text_encoder_to('cpu')
+                    flush()
+
                 self.ensure_params_requires_grad()
                 self.progress_bar.unpause()
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a8eb35fb..a1b0dd42 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -377,6 +377,10 @@ class TrainConfig:
         self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
         self.disable_sampling = kwargs.get('disable_sampling', False)
 
+        # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
+        # will make training faster and use less vram
+        self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
+
 
 class ModelConfig:
     def __init__(self, **kwargs):
diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py
index 7f82f37c..2ee1d555 100644
--- a/toolkit/sd_device_states_presets.py
+++ b/toolkit/sd_device_states_presets.py
@@ -40,6 +40,7 @@ def get_train_sd_device_state_preset(
         train_adapter: bool = False,
         train_embedding: bool = False,
         train_refiner: bool = False,
+        unload_text_encoder: bool = False,
 ):
     preset = copy.deepcopy(empty_preset)
     if not cached_latents:
@@ -88,4 +89,9 @@ def get_train_sd_device_state_preset(
         preset['unet']['device'] = device
         preset['text_encoder']['device'] = device
 
+    if unload_text_encoder:
+        preset['text_encoder']['training'] = False
+        preset['text_encoder']['requires_grad'] = False
+        preset['text_encoder']['device'] = 'cpu'
+
     return preset
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index b07fbc68..dd9c6062 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2635,3 +2635,10 @@ class StableDiffusion:
         }
 
         self.set_device_state(state)
+
+    def text_encoder_to(self, *args, **kwargs):
+        if isinstance(self.text_encoder, list):
+            for encoder in self.text_encoder:
+                encoder.to(*args, **kwargs)
+        else:
+            self.text_encoder.to(*args, **kwargs)
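
A minimal, self-contained sketch of the idea behind this PR, for reviewers who want the shape of it outside the trainer: encode the fixed prompts once, move the text encoder to CPU, and reuse the cached embeddings on every step. This is not the toolkit's API; `encode`, `training_step`, the model name, and the trigger word are illustrative stand-ins.

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative model; any CLIP-style text encoder behaves the same way.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")


def encode(prompt: str, device: str) -> torch.Tensor:
    # One-off prompt encoding; gradients are never needed for cached embeds.
    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return encoder(tokens.input_ids.to(device))[0]


device = "cuda" if torch.cuda.is_available() else "cpu"
encoder.to(device)

# Encode once, up front, mirroring what the SDTrainer changes do.
cached_blank = encode("", device)
cached_trigger = encode("my_trigger_word", device)  # hypothetical trigger word

# The encoder is no longer needed on the GPU for the rest of training.
encoder.to("cpu")
torch.cuda.empty_cache()


def training_step(batch_size: int, is_reg: bool) -> torch.Tensor:
    # Pick a cached embedding instead of re-encoding a caption each step;
    # regularization batches fall back to the blank prompt, as in the diff.
    embeds = cached_blank if is_reg else cached_trigger
    return embeds.detach().repeat(batch_size, 1, 1)
```

The real implementation does the equivalent with `PromptEmbeds.clone().detach()` plus `concat_prompt_embeds` to expand the cached embedding to the batch size.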
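On enabling it: `TrainConfig` reads the flag with `kwargs.get('unload_text_encoder', False)`, so it is opt-in and off by default. A hypothetical snippet, assuming `TrainConfig` accepts plain keyword arguments the way `ModelConfig`'s `__init__(self, **kwargs)` does:

```python
from toolkit.config_modules import TrainConfig

# 'unload_text_encoder' comes from this diff and defaults to False when omitted.
# Assumes TrainConfig(**kwargs) mirrors ModelConfig's constructor pattern.
train_config = TrainConfig(unload_text_encoder=True)
assert train_config.unload_text_encoder is True
```

Two guardrails in the diff are worth noting: setting the flag together with `train_text_encoder` raises a `ValueError` before the train loop starts, and prior predictions reject it outright, since they would need the text encoder resident to re-encode prompts.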