diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index d2aa4b24..912c241f 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -814,6 +814,9 @@ class DatasetConfig: self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc if self.control_path == '': self.control_path = None + + # color for transparent reigon of control images with transparency + self.control_transparent_color: List[int] = kwargs.get('control_transparent_color', [0, 0, 0]) # inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will # be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 69df9cbd..6d231f49 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -876,8 +876,19 @@ class ControlFileItemDTOMixin: for control_path in control_path_list: try: - img = Image.open(control_path).convert('RGB') + img = Image.open(control_path) img = exif_transpose(img) + + if img.mode in ("RGBA", "LA"): + # Create a background with the specified transparent color + transparent_color = tuple(self.dataset_config.control_transparent_color) + background = Image.new("RGB", img.size, transparent_color) + # Paste the image on top using its alpha channel as mask + background.paste(img, mask=img.getchannel("A")) + img = background + else: + # Already no alpha channel + img = img.convert("RGB") except Exception as e: print_acc(f"Error: {e}") print_acc(f"Error loading image: {control_path}") diff --git a/version.py b/version.py index caa84b83..9829a106 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.5.7" \ No newline at end of file +VERSION = "0.5.8" \ No newline at end of file