diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 927286e4..90cedbdb 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -392,7 +392,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def load_training_state_from_metadata(self, path):
         meta = load_metadata_from_safetensors(path)
         # if 'training_info' in Orderdict keys
-        if 'training_info' in meta and 'step' in meta['training_info']:
+        if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
             self.step_num = meta['training_info']['step']
             self.start_step = self.step_num
             print(f"Found step {self.step_num} in metadata, starting from there")
@@ -796,6 +796,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 self.params.append(param)
 
+        if self.train_config.start_step is not None:
+            self.step_num = self.train_config.start_step
+            self.start_step = self.step_num
+
         optimizer_type = self.train_config.optimizer.lower()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index cc6800f3..492357eb 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -38,6 +38,8 @@ dataset_config = DatasetConfig(
     default_caption='default',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
+    augments=['ColorJitter', 'RandomEqualize'],
+
 )
 
 dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
@@ -48,22 +50,22 @@ for batch in dataloader:
     batch: 'DataLoaderBatchDTO'
     img_batch = batch.tensor
 
-    # chunks = torch.chunk(img_batch, batch_size, dim=0)
-    # # put them so they are size by side
-    # big_img = torch.cat(chunks, dim=3)
-    # big_img = big_img.squeeze(0)
-    #
-    # min_val = big_img.min()
-    # max_val = big_img.max()
-    #
-    # big_img = (big_img / 2 + 0.5).clamp(0, 1)
-    #
-    # # convert to image
-    # img = transforms.ToPILImage()(big_img)
-    #
-    # show_img(img)
-    #
-    # time.sleep(1.0)
+    chunks = torch.chunk(img_batch, batch_size, dim=0)
+    # put them so they are side by side
+    big_img = torch.cat(chunks, dim=3)
+    big_img = big_img.squeeze(0)
+
+    min_val = big_img.min()
+    max_val = big_img.max()
+
+    big_img = (big_img / 2 + 0.5).clamp(0, 1)
+
+    # convert to image
+    img = transforms.ToPILImage()(big_img)
+
+    show_img(img)
+
+    time.sleep(1.0)
 
 cv2.destroyAllWindows()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index bd76b7b4..7e907ff1 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -112,6 +112,7 @@ class TrainConfig:
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
         self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
         self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
+        self.start_step = kwargs.get('start_step', None)
 
 
 class ModelConfig:
@@ -221,12 +222,18 @@ class DatasetConfig:
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
+        self.augments: List[str] = kwargs.get('augments', [])
         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
         self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
 
+        if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
+            print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
+            self.cache_latents = False
+            self.cache_latents_to_disk = False
+
         # legacy compatability
         legacy_caption_type = kwargs.get('caption_type', None)
         if legacy_caption_type:
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 791acb3c..842375f0 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -49,6 +49,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
         self.flip_x: bool = kwargs.get('flip_x', False)
-        self.flip_y: bool = kwargs.get('flip_x', False)
+        self.flip_y: bool = kwargs.get('flip_y', False)
+        self.augments: List[str] = self.dataset_config.augments
         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index ba9c95d0..d39e7f03 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -25,10 +25,15 @@ if TYPE_CHECKING:
     from toolkit.data_loader import AiToolkitDataset
     from toolkit.data_transfer_object.data_loader import FileItemDTO
 
 
-# def get_associated_caption_from_img_path(img_path):
+transforms_dict = {
+    'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
+    'RandomEqualize': transforms.RandomEqualize(p=0.2),
+}
+
+
 
 class CaptionMixin:
     def get_caption_item(self: 'AiToolkitDataset', index):
         if not hasattr(self, 'caption_type'):
@@ -287,6 +292,12 @@ class ImageProcessingDTOMixin:
                 img = transforms.CenterCrop(min_img_size)(img)
                 img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)
 
+        if self.augments is not None and len(self.augments) > 0:
+            # do augmentations
+            for augment in self.augments:
+                if augment in transforms_dict:
+                    img = transforms_dict[augment](img)
+
         if transform:
             img = transform(img)
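
Example usage of the two options this patch introduces. This is a minimal sketch, not part of the patch: it assumes `get_dataloader_from_datasets` is importable from `toolkit.data_loader` and that `folder_path` is a valid `DatasetConfig` field; all paths and values are placeholders.

# Minimal sketch; paths and step values are placeholders.
from toolkit.config_modules import DatasetConfig, TrainConfig
from toolkit.data_loader import get_dataloader_from_datasets  # assumed import path

# 'augments' names keys of transforms_dict in toolkit/dataloader_mixins.py.
# Because augmented pixels differ on every pass, DatasetConfig disables
# cache_latents / cache_latents_to_disk (with a warning) when augments are set.
dataset_config = DatasetConfig(
    folder_path='/path/to/images',  # assumed field name
    buckets=True,
    augments=['ColorJitter', 'RandomEqualize'],
)
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=4)

# 'start_step' forces training to begin at a fixed step: BaseSDTrainProcess
# copies it into step_num/start_step, and load_training_state_from_metadata
# ignores the step stored in safetensors metadata whenever it is set.
train_config = TrainConfig(start_step=1000)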