mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Bug fixes, work on maing IP adapters more customizable.
This commit is contained in:
@@ -56,6 +56,30 @@ transforms_dict = {
|
||||
caption_ext_list = ['txt', 'json', 'caption']
|
||||
|
||||
|
||||
def standardize_images(images):
|
||||
"""
|
||||
Standardize the given batch of images using the specified mean and std.
|
||||
Expects values of 0 - 1
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): A batch of images in the shape of (N, C, H, W),
|
||||
where N is the number of images, C is the number of channels,
|
||||
H is the height, and W is the width.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Standardized images.
|
||||
"""
|
||||
mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
# Define the normalization transform
|
||||
normalize = transforms.Normalize(mean=mean, std=std)
|
||||
|
||||
# Apply normalization to each image in the batch
|
||||
standardized_images = torch.stack([normalize(img) for img in images])
|
||||
|
||||
return standardized_images
|
||||
|
||||
def clean_caption(caption):
|
||||
# remove any newlines
|
||||
caption = caption.replace('\n', ', ')
|
||||
@@ -520,8 +544,13 @@ class ControlFileItemDTOMixin:
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
|
||||
self.control_tensor = transforms.ToTensor()(img)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.control_tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
self.control_tensor = transform(img)
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.control_tensor = None
|
||||
@@ -624,6 +653,8 @@ class AugmentationFileItemDTOMixin:
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
# self.augmentations: Union[None, List[Augments]] = None
|
||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.aug_transform: Union[None, A.Compose] = None
|
||||
self.aug_replay_spatial_transforms = None
|
||||
self.build_augmentation_transform()
|
||||
|
||||
def build_augmentation_transform(self: 'FileItemDTO'):
|
||||
@@ -643,7 +674,8 @@ class AugmentationFileItemDTOMixin:
|
||||
# add the method to the list
|
||||
augmentation_list.append(method(**aug.params))
|
||||
|
||||
self.aug_transform = A.Compose(augmentation_list)
|
||||
# add additional targets so we can augment the control image
|
||||
self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'})
|
||||
|
||||
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
|
||||
|
||||
@@ -659,7 +691,17 @@ class AugmentationFileItemDTOMixin:
|
||||
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
||||
|
||||
# apply augmentations
|
||||
augmented = self.aug_transform(image=open_cv_image)["image"]
|
||||
transformed = self.aug_transform(image=open_cv_image)
|
||||
augmented = transformed["image"]
|
||||
|
||||
# save just the spatial transforms for controls and masks
|
||||
augmented_params = transformed["replay"]
|
||||
spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop',
|
||||
'ElasticTransform', 'GridDistortion', 'OpticalDistortion']
|
||||
# only store the spatial transforms
|
||||
augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]
|
||||
|
||||
self.aug_replay_spatial_transforms = augmented_params
|
||||
|
||||
# convert back to RGB tensor
|
||||
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
|
||||
@@ -671,6 +713,38 @@ class AugmentationFileItemDTOMixin:
|
||||
|
||||
return augmented_tensor
|
||||
|
||||
# augment control images spatially consistent with transforms done to the main image
|
||||
def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ):
|
||||
if self.aug_replay_spatial_transforms is None:
|
||||
# no transforms
|
||||
return transform(img)
|
||||
|
||||
# save colorspace to convert back to
|
||||
colorspace = img.mode
|
||||
|
||||
# convert to rgb
|
||||
img = img.convert('RGB')
|
||||
|
||||
open_cv_image = np.array(img)
|
||||
# Convert RGB to BGR
|
||||
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
||||
|
||||
# Replay transforms
|
||||
transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image)
|
||||
augmented = transformed["image"]
|
||||
|
||||
# convert back to RGB tensor
|
||||
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# convert to PIL image
|
||||
augmented = Image.fromarray(augmented)
|
||||
|
||||
# convert back to original colorspace
|
||||
augmented = augmented.convert(colorspace)
|
||||
|
||||
augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
|
||||
return augmented_tensor
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.unaugmented_tensor = None
|
||||
|
||||
@@ -760,7 +834,13 @@ class MaskFileItemDTOMixin:
|
||||
else:
|
||||
raise Exception("Mask images not supported for non-bucket datasets")
|
||||
|
||||
self.mask_tensor = transforms.ToTensor()(img)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.mask_tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
self.mask_tensor = transform(img)
|
||||
self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
|
||||
# convert to grayscale
|
||||
|
||||
@@ -776,12 +856,7 @@ class UnconditionalFileItemDTOMixin:
|
||||
self.unconditional_path: Union[str, None] = None
|
||||
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_latent: Union[torch.Tensor, None] = None
|
||||
self.unconditional_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
self.unconditional_transforms = self.dataloader_transforms
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
|
||||
if dataset_config.unconditional_path is not None:
|
||||
@@ -835,7 +910,10 @@ class UnconditionalFileItemDTOMixin:
|
||||
else:
|
||||
raise Exception("Unconditional images are not supported for non-bucket datasets")
|
||||
|
||||
self.unconditional_tensor = self.unconditional_transforms(img)
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms)
|
||||
else:
|
||||
self.unconditional_tensor = self.unconditional_transforms(img)
|
||||
|
||||
def cleanup_unconditional(self: 'FileItemDTO'):
|
||||
self.unconditional_tensor = None
|
||||
|
||||
Reference in New Issue
Block a user