Bugfixes. Added small augmentations to dataloader. Will switch to abluminations soon though. Added ability to adjust step count on start to override what is in the file

This commit is contained in:
Jaret Burkett
2023-09-20 05:30:10 -06:00
parent 0f105690cc
commit 19255cdc7c
5 changed files with 43 additions and 18 deletions

View File

@@ -112,6 +112,7 @@ class TrainConfig:
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
self.start_step = kwargs.get('start_step', None)
class ModelConfig:
@@ -221,12 +222,18 @@ class DatasetConfig:
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
self.cache_latents = False
self.cache_latents_to_disk = False
# legacy compatability
legacy_caption_type = kwargs.get('caption_type', None)
if legacy_caption_type:

View File

@@ -49,6 +49,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_x', False)
self.augments: List[str] = self.dataset_config.augments
self.network_weight: float = self.dataset_config.network_weight
self.is_reg = self.dataset_config.is_reg

View File

@@ -25,10 +25,15 @@ if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO
# def get_associated_caption_from_img_path(img_path):
transforms_dict = {
'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
'RandomEqualize': transforms.RandomEqualize(p=0.2),
}
class CaptionMixin:
def get_caption_item(self: 'AiToolkitDataset', index):
if not hasattr(self, 'caption_type'):
@@ -287,6 +292,12 @@ class ImageProcessingDTOMixin:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)
if self.augments is not None and len(self.augments) > 0:
# do augmentations
for augment in self.augments:
if augment in transforms_dict:
img = transforms_dict[augment](img)
if transform:
img = transform(img)