From 6280284d8bcc830c29ec63ec66bef807faa5e4e6 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 16 Nov 2023 20:26:11 -0700
Subject: [PATCH] Fixed cleanup of embeddings.

---
 jobs/process/BaseSDTrainProcess.py | 15 ++++++++++++---
 toolkit/basic.py                   | 15 +++++++++++++++
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 8a06aa45..42d01786 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -175,6 +175,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return generate_image_config_list
 
     def sample(self, step=None, is_first=False):
+        flush()
         sample_folder = os.path.join(self.save_root, 'samples')
         gen_img_config_list = []
 
@@ -284,6 +285,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         safetensors_files = [f for f in items if f.endswith('.safetensors')]
         pt_files = [f for f in items if f.endswith('.pt')]
         directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
+        embed_files = []
+        # do embedding files
+        if self.embed_config is not None:
+            embed_pattern = f"{self.embed_config.trigger}_*"
+            embed_items = glob.glob(os.path.join(self.save_root, embed_pattern))
+            # will end in safetensors or pt
+            embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
 
         # Sort the lists by creation time if they are not empty
         if safetensors_files:
@@ -292,6 +300,8 @@
             pt_files.sort(key=os.path.getctime)
         if directories:
             directories.sort(key=os.path.getctime)
+        if embed_files:
+            embed_files.sort(key=os.path.getctime)
 
         # Combine and sort the lists
         combined_items = safetensors_files + directories + pt_files
@@ -302,10 +312,9 @@
             :-self.save_config.max_step_saves_to_keep] if safetensors_files else []
         pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
         directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
-        combined_items_to_remove = combined_items[
-            :-self.save_config.max_step_saves_to_keep] if combined_items else []
+        embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
 
-        items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove
+        items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove
 
         # remove all but the latest max_step_saves_to_keep
         # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
diff --git a/toolkit/basic.py b/toolkit/basic.py
index 9a64ca11..6a70bf61 100644
--- a/toolkit/basic.py
+++ b/toolkit/basic.py
@@ -13,6 +13,21 @@ def flush(garbage_collect=True):
         gc.collect()
 
 
+def get_mean_std(tensor):
+    if len(tensor.shape) == 3:
+        tensor = tensor.unsqueeze(0)
+    elif len(tensor.shape) != 4:
+        raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
+    mean, variance = torch.mean(
+        tensor, dim=[2, 3], keepdim=True
+    ), torch.var(
+        tensor, dim=[2, 3],
+        keepdim=True
+    )
+    std = torch.sqrt(variance + 1e-5)
+    return mean, std
+
+
 def adain(content_features, style_features):
     # Assumes that the content and style features are of shape (batch_size, channels, width, height)
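
Note on the first hunk: sample() now calls flush() before generating preview images. The second diff shows flush(garbage_collect=True) only as context; from that context it runs gc.collect(), and it presumably also clears the CUDA allocator cache so sampling starts with as much free VRAM as possible. A minimal sketch of such a helper, where the torch.cuda.empty_cache() call is an assumption rather than something visible in this patch:

import gc

import torch


def flush(garbage_collect=True):
    # assumed body: release cached CUDA blocks so new allocations can reuse them
    torch.cuda.empty_cache()
    if garbage_collect:
        # drop unreachable Python objects that may still be pinning GPU tensors
        gc.collect()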
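The changes in BaseSDTrainProcess.py extend the old-save cleanup so that embedding files saved as "<trigger>_*.safetensors" or "<trigger>_*.pt" are pruned under the same max_step_saves_to_keep policy as checkpoints and directories; previously embedding saves were never added to items_to_remove, so they accumulated indefinitely. The retention pattern, reduced to a standalone sketch (the function name and arguments are illustrative, and it deletes files directly, whereas the patch only collects them into items_to_remove for removal elsewhere):

import glob
import os


def prune_old_embeddings(save_root, trigger, keep):
    # gather embedding saves matching "<trigger>_*" that end in .safetensors or .pt
    pattern = os.path.join(save_root, f"{trigger}_*")
    embeds = [f for f in glob.glob(pattern) if f.endswith(('.safetensors', '.pt'))]
    # sort oldest-first by creation time, then drop everything except the newest `keep`
    embeds.sort(key=os.path.getctime)
    for path in (embeds[:-keep] if keep > 0 else []):
        os.remove(path)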
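toolkit/basic.py gains get_mean_std(), which returns the per-channel spatial mean and standard deviation (with a small epsilon for numerical stability) of a 4-D activation tensor, unsqueezing a 3-D input first. These are exactly the statistics adaptive instance normalization (AdaIN) needs. The body of adain() lies outside this hunk; a typical AdaIN built on get_mean_std would look like the sketch below (illustrative, not the repository's implementation):

import torch

from toolkit.basic import get_mean_std


def adain_sketch(content_features, style_features):
    # strip the content statistics, then re-apply the style statistics per channel
    c_mean, c_std = get_mean_std(content_features)
    s_mean, s_std = get_mean_std(style_features)
    normalized = (content_features - c_mean) / c_std
    return normalized * s_std + s_mean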