Add `alpha_mask` dataset option.

When `alpha_mask` is true, the training image's own alpha channel is used as
the focus mask: the mask item points `mask_path` at the image itself and the
mask loader copies alpha into all three colour channels. When it is false the
existing `mask_path` lookup is used unchanged.

Notes:
  * The training image is always opened from `self.path`; only the mask
    loader reads `self.mask_path`.
  * Images without an alpha channel are handled in both loaders instead of
    failing on the channel slicing.
  * The `index` lines are informational; blob hashes were not regenerated.

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b23b762a..b4ca167f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -235,6 +235,7 @@ class DatasetConfig:
         self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = kwargs.get('augments', [])
         self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
+        self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
         self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
         self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for . 0 - 1
         self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in json data, will be used as auto crop scale point of interes
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 6cbab537..3bd786c6 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -7,6 +7,7 @@ import random
 from collections import OrderedDict
 from typing import TYPE_CHECKING, List, Dict, Union
 
+import numpy as np
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
@@ -293,11 +294,21 @@ class ImageProcessingDTOMixin:
                 self.load_mask_image()
             return
         try:
-            img = Image.open(self.path).convert('RGB')
+            img = Image.open(self.path)
             img = exif_transpose(img)
         except Exception as e:
             print(f"Error: {e}")
             print(f"Error loading image: {self.path}")
+
+        if self.use_alpha_as_mask and img.mode == 'RGBA':
+            # we do this to make sure it does not replace the alpha with another color
+            # we want the image just without the alpha channel
+            np_img = np.array(img)
+            # strip off alpha
+            np_img = np_img[:, :, :3]
+            img = Image.fromarray(np_img)
+
+        img = img.convert('RGB')
         w, h = img.size
         if w > h and self.scale_to_width < self.scale_to_height:
             # throw error, they should match
@@ -443,11 +454,17 @@ class MaskFileItemDTOMixin:
         self.has_mask_image = False
         self.mask_path: Union[str, None] = None
         self.mask_tensor: Union[torch.Tensor, None] = None
+        self.use_alpha_as_mask: bool = False
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         self.mask_min_value = dataset_config.mask_min_value
-        if dataset_config.mask_path is not None:
+        if dataset_config.alpha_mask:
+            # the mask is read from the training image's own alpha channel
+            self.use_alpha_as_mask = True
+            self.mask_path = kwargs.get('path', None)
+            self.has_mask_image = True
+        elif dataset_config.mask_path is not None:
             # find the control image path
             mask_path = dataset_config.mask_path
             # we are using control images
             img_path = kwargs.get('path', None)
             img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
@@ -460,11 +477,22 @@ class MaskFileItemDTOMixin:
 
     def load_mask_image(self: 'FileItemDTO'):
         try:
-            img = Image.open(self.mask_path).convert('RGB')
+            img = Image.open(self.mask_path)
             img = exif_transpose(img)
         except Exception as e:
             print(f"Error: {e}")
             print(f"Error loading image: {self.mask_path}")
+
+        if self.use_alpha_as_mask:
+            # pipeline expects an rgb image so we need to put alpha in all channels
+            # convert to RGBA first so an image without alpha yields a fully white mask
+            np_img = np.array(img.convert('RGBA'))
+            np_img[:, :, :3] = np_img[:, :, 3:]
+
+            np_img = np_img[:, :, :3]
+            img = Image.fromarray(np_img)
+
+        img = img.convert('RGB')
         w, h = img.size
         if w > h and self.scale_to_width < self.scale_to_height:
             # throw error, they should match