From 38e441a29ccee840f0ec48ef197db28f06dbc4e6 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 12 Oct 2023 21:02:47 -0600
Subject: [PATCH] Allow flipping for point-of-interest autocropping. Allow
 num_repeats. Fixed some bugs with the new FreeU support.

---
 jobs/process/BaseSDTrainProcess.py |  6 ++++++
 toolkit/config_modules.py          |  3 ++-
 toolkit/data_loader.py             |  4 ++++
 toolkit/dataloader_mixins.py       | 15 ++++++++++++---
 toolkit/train_tools.py             |  3 +++
 5 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 1a08dbc3..541e9b03 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -928,6 +928,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             flush()
         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
+            if self.train_config.free_u:
+                self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
             self.progress_bar.unpause()
             with torch.no_grad():
                 # if is even step and we have a reg dataset, use that
@@ -1000,6 +1002,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if is_sample_step:
                 self.progress_bar.pause()
                 # print above the progress bar
+                if self.train_config.free_u:
+                    self.sd.pipeline.disable_freeu()
                 self.sample(self.step_num)
                 self.progress_bar.unpause()

@@ -1047,6 +1051,8 @@ class BaseSDTrainProcess(BaseTrainProcess):

         # flush()
         self.progress_bar.close()
+        if self.train_config.free_u:
+            self.sd.pipeline.disable_freeu()
         self.sample(self.step_num + 1)
         print("")
         self.save()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b4ca167f..8a616805 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -120,6 +120,7 @@ class TrainConfig:
         self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
         self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
         self.start_step = kwargs.get('start_step', None)
+        self.free_u = kwargs.get('free_u', False)


 class ModelConfig:
@@ -239,7 +240,7 @@ class DatasetConfig:
         self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
         self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for mask. 0 - 1
         self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in json data, will be used as auto crop scale point of interest
-
+        self.num_repeats: int = kwargs.get('num_repeats', 1)  # number of times to repeat dataset
         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
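Usage note: the two options added above are read straight from kwargs. A minimal sketch of how they might be set, assuming `TrainConfig` and `DatasetConfig` accept plain keyword arguments as the `kwargs.get` calls suggest (the values shown are illustrative, not defaults from this patch, and 'person' is a hypothetical POI key):

    from toolkit.config_modules import TrainConfig, DatasetConfig

    # enable the new free_u flag; BaseSDTrainProcess above hard-codes
    # s1=0.9, s2=0.2, b1=1.1, b2=1.2 when it is set, and disables FreeU
    # again before sampling
    train_config = TrainConfig(free_u=True)

    # walk a small dataset 10x per pass and auto-crop around the
    # 'person' point of interest read from the JSON metadata
    dataset_config = DatasetConfig(poi='person', num_repeats=10)
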
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 84b5d229..7f5cfb62 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -377,6 +377,10 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         # keys are file paths
         file_list = list(self.caption_dict.keys())

+        if self.dataset_config.num_repeats > 1:
+            # repeat the list
+            file_list = file_list * self.dataset_config.num_repeats
+
         # this might take a while
         print(f" - Preprocessing image dimensions")
         bad_count = 0
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 3bd786c6..00cb536e 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -156,8 +156,9 @@
             # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
             max_scale_factor = max(width_scale_factor, height_scale_factor)
-            file_item.scale_to_width = int(width * max_scale_factor)
-            file_item.scale_to_height = int(height * max_scale_factor)
+            # round up
+            file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
+            file_item.scale_to_height = int(math.ceil(height * max_scale_factor))

             file_item.crop_height = bucket_resolution["height"]
             file_item.crop_width = bucket_resolution["width"]

@@ -294,7 +295,7 @@
                 self.load_mask_image()
             return
         try:
-            img = Image.open(self.mask_path)
+            img = Image.open(self.path)
             img = exif_transpose(img)
         except Exception as e:
             print(f"Error: {e}")
@@ -583,6 +584,14 @@
             self.poi_width = int(poi['width'])
             self.poi_height = int(poi['height'])

+        # handle flipping
+        if kwargs.get('flip_x', False):
+            # flip the poi
+            self.poi_x = self.width - self.poi_x - self.poi_width
+        if kwargs.get('flip_y', False):
+            # flip the poi
+            self.poi_y = self.height - self.poi_y - self.poi_height
+
     def setup_poi_bucket(self: 'FileItemDTO'):
         # we are using poi, so we need to calculate the bucket based on the poi

diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 5bee2f6b..1e69d798 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -5,6 +5,9 @@
 import os
 import time
 from typing import TYPE_CHECKING, Union
 import sys
+
+from torch.cuda.amp import GradScaler
+
 from toolkit.paths import SD_SCRIPTS_ROOT
 sys.path.append(SD_SCRIPTS_ROOT)
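Usage note: the POI flip added in the last dataloader_mixins.py hunk is the standard mirror of an axis-aligned box. Flipping an image of width `W` maps a box that starts at `poi_x` with width `poi_width` to one that starts at `W - poi_x - poi_width` (likewise for y with heights). A self-contained check of that arithmetic, independent of the toolkit classes:

    # illustrative check of the box-mirroring arithmetic used in PoiFileItemDTOMixin
    def flip_box_x(x: int, box_w: int, img_w: int) -> int:
        return img_w - x - box_w

    # a 20px-wide box starting at x=10 in a 100px-wide image
    # starts at x=70 after a horizontal flip
    assert flip_box_x(10, 20, 100) == 70
    # flipping twice restores the original position
    assert flip_box_x(flip_box_x(10, 20, 100), 20, 100) == 10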