Bug fixes, work on making IP adapters more customizable.

This commit is contained in:
Jaret Burkett
2023-12-24 08:32:39 -07:00
parent 7703e3a15e
commit 0f8daa5612
7 changed files with 243 additions and 36 deletions

View File

@@ -56,6 +56,30 @@ transforms_dict = {
caption_ext_list = ['txt', 'json', 'caption']
def standardize_images(images):
    """
    Standardize the given batch of images using a fixed per-channel mean and std.
    Expects values of 0 - 1.

    The whole batch is normalized with a single broadcasted tensor operation
    instead of a Python loop over images, so an empty batch (N == 0) is
    handled and no intermediate per-image tensors are created.

    Args:
        images (torch.Tensor): A batch of images in the shape of (N, C, H, W),
            where N is the number of images, C is the number of channels,
            H is the height, and W is the width. A single (C, H, W) image or
            an iterable of (C, H, W) tensors is accepted as well.

    Returns:
        torch.Tensor: Standardized images, same shape/dtype/device as the
            input. The input tensor is not modified.

    Raises:
        TypeError: If the images are not floating point tensors.
        ValueError: If the channel dimension does not match the 3 channel
            statistics used here.
    """
    # NOTE(review): these look like the CLIP image preprocessing statistics - confirm
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]

    # keep supporting any iterable of image tensors, not only a batched tensor
    if not torch.is_tensor(images):
        images = torch.stack(list(images))

    # integer inputs would silently truncate the statistics, so refuse them
    if not images.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {images.dtype}.")

    if images.ndim < 3 or images.shape[-3] != len(mean):
        raise ValueError(
            f"Expected images of shape (..., {len(mean)}, H, W). Got {tuple(images.shape)}."
        )

    # shape (C, 1, 1) broadcasts over the trailing (C, H, W) dims of every image
    mean_tensor = torch.as_tensor(mean, dtype=images.dtype, device=images.device).view(-1, 1, 1)
    std_tensor = torch.as_tensor(std, dtype=images.dtype, device=images.device).view(-1, 1, 1)

    # out-of-place so the caller's tensor is left untouched
    return (images - mean_tensor) / std_tensor
def clean_caption(caption):
# remove any newlines
caption = caption.replace('\n', ', ')
@@ -520,8 +544,13 @@ class ControlFileItemDTOMixin:
))
else:
raise Exception("Control images not supported for non-bucket datasets")
self.control_tensor = transforms.ToTensor()(img)
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
self.control_tensor = self.augment_spatial_control(img, transform=transform)
else:
self.control_tensor = transform(img)
def cleanup_control(self: 'FileItemDTO'):
self.control_tensor = None
@@ -624,6 +653,8 @@ class AugmentationFileItemDTOMixin:
self.unaugmented_tensor: Union[torch.Tensor, None] = None
# self.augmentations: Union[None, List[Augments]] = None
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.aug_transform: Union[None, A.Compose] = None
self.aug_replay_spatial_transforms = None
self.build_augmentation_transform()
def build_augmentation_transform(self: 'FileItemDTO'):
@@ -643,7 +674,8 @@ class AugmentationFileItemDTOMixin:
# add the method to the list
augmentation_list.append(method(**aug.params))
self.aug_transform = A.Compose(augmentation_list)
# add additional targets so we can augment the control image
self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'})
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
@@ -659,7 +691,17 @@ class AugmentationFileItemDTOMixin:
open_cv_image = open_cv_image[:, :, ::-1].copy()
# apply augmentations
augmented = self.aug_transform(image=open_cv_image)["image"]
transformed = self.aug_transform(image=open_cv_image)
augmented = transformed["image"]
# save just the spatial transforms for controls and masks
augmented_params = transformed["replay"]
spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop',
'ElasticTransform', 'GridDistortion', 'OpticalDistortion']
# only store the spatial transforms
augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]
self.aug_replay_spatial_transforms = augmented_params
# convert back to RGB tensor
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
@@ -671,6 +713,38 @@ class AugmentationFileItemDTOMixin:
return augmented_tensor
# augment control images spatially consistent with transforms done to the main image
def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose]):
    """
    Replay the stored spatial augmentations on a secondary image.

    Applies the replay parameters saved in ``self.aug_replay_spatial_transforms``
    to ``img`` and returns the result as a tensor.

    Args:
        img: PIL image to augment (control, mask or unconditional image).
        transform: Callable that converts the PIL image into a tensor. When
            None, a plain ``transforms.ToTensor()`` is used on every path.

    Returns:
        The tensor produced by ``transform`` (or ``ToTensor``) from the
        augmented image, or from the untouched image when no replay
        parameters are stored.
    """
    # resolve the fallback once so both return paths treat a None transform the same way
    to_tensor = transforms.ToTensor() if transform is None else transform

    if self.aug_replay_spatial_transforms is None:
        # no transforms
        return to_tensor(img)

    # save colorspace to convert back to
    colorspace = img.mode

    # the replay runs on a 3 channel BGR array, so go through RGB first
    rgb_array = np.array(img.convert('RGB'))
    # Convert RGB to BGR
    open_cv_image = rgb_array[:, :, ::-1].copy()

    # Replay transforms
    transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image)
    augmented = transformed["image"]

    # convert back to RGB
    augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)

    # convert to PIL image in the original colorspace
    augmented = Image.fromarray(augmented).convert(colorspace)

    return to_tensor(augmented)
def cleanup_control(self: 'FileItemDTO'):
self.unaugmented_tensor = None
@@ -760,7 +834,13 @@ class MaskFileItemDTOMixin:
else:
raise Exception("Mask images not supported for non-bucket datasets")
self.mask_tensor = transforms.ToTensor()(img)
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
self.mask_tensor = self.augment_spatial_control(img, transform=transform)
else:
self.mask_tensor = transform(img)
self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
# convert to grayscale
@@ -776,12 +856,7 @@ class UnconditionalFileItemDTOMixin:
self.unconditional_path: Union[str, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latent: Union[torch.Tensor, None] = None
self.unconditional_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.unconditional_transforms = self.dataloader_transforms
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.unconditional_path is not None:
@@ -835,7 +910,10 @@ class UnconditionalFileItemDTOMixin:
else:
raise Exception("Unconditional images are not supported for non-bucket datasets")
self.unconditional_tensor = self.unconditional_transforms(img)
if self.aug_replay_spatial_transforms:
self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms)
else:
self.unconditional_tensor = self.unconditional_transforms(img)
def cleanup_unconditional(self: 'FileItemDTO'):
self.unconditional_tensor = None