diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 1787f0da..03944e6c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -304,19 +304,11 @@ class SDTrainer(BaseSDTrainProcess): # handle unload text encoder if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: + print_acc("Caching embeddings and unloading text encoder") with torch.no_grad(): if self.train_config.train_text_encoder: raise ValueError("Cannot unload text encoder if training text encoder") # cache embeddings - - print_acc("\n***** UNLOADING TEXT ENCODER *****") - if self.is_caching_text_embeddings: - print_acc("Embeddings cached to disk. We dont need the text encoder anymore") - else: - print_acc("This will train only with a blank prompt or trigger word, if set") - print_acc("If this is not what you want, remove the unload_text_encoder flag") - print_acc("***********************************") - print_acc("") self.sd.text_encoder_to(self.device_torch) encode_kwargs = {} if self.sd.encode_control_in_text_embeddings: @@ -332,6 +324,15 @@ class SDTrainer(BaseSDTrainProcess): self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) self.cache_sample_prompts() + + print_acc("\n***** UNLOADING TEXT ENCODER *****") + if self.is_caching_text_embeddings: + print_acc("Embeddings cached to disk. We dont need the text encoder anymore") + else: + print_acc("This will train only with a blank prompt or trigger word, if set") + print_acc("If this is not what you want, remove the unload_text_encoder flag") + print_acc("***********************************") + print_acc("") # unload the text encoder if self.is_caching_text_embeddings: