mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Bugfixes. Added small augmentations to dataloader. Will switch to abluminations soon though. Added ability to adjust step count on start to override what is in the file
This commit is contained in:
@@ -25,10 +25,15 @@ if TYPE_CHECKING:
|
||||
from toolkit.data_loader import AiToolkitDataset
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
# def get_associated_caption_from_img_path(img_path):
|
||||
|
||||
|
||||
transforms_dict = {
|
||||
'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
|
||||
'RandomEqualize': transforms.RandomEqualize(p=0.2),
|
||||
}
|
||||
|
||||
|
||||
class CaptionMixin:
|
||||
def get_caption_item(self: 'AiToolkitDataset', index):
|
||||
if not hasattr(self, 'caption_type'):
|
||||
@@ -287,6 +292,12 @@ class ImageProcessingDTOMixin:
|
||||
img = transforms.CenterCrop(min_img_size)(img)
|
||||
img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)
|
||||
|
||||
if self.augments is not None and len(self.augments) > 0:
|
||||
# do augmentations
|
||||
for augment in self.augments:
|
||||
if augment in transforms_dict:
|
||||
img = transforms_dict[augment](img)
|
||||
|
||||
if transform:
|
||||
img = transform(img)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user