mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-20 12:23:57 +00:00
Allow for alpha to be used as a mask
This commit is contained in:
@@ -235,6 +235,7 @@ class DatasetConfig:
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
|
||||
@@ -7,6 +7,7 @@ import random
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, List, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
@@ -293,11 +294,21 @@ class ImageProcessingDTOMixin:
|
||||
self.load_mask_image()
|
||||
return
|
||||
try:
|
||||
img = Image.open(self.path).convert('RGB')
|
||||
img = Image.open(self.mask_path)
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.path}")
|
||||
|
||||
if self.use_alpha_as_mask:
|
||||
# we do this to make sure it does not replace the alpha with another color
|
||||
# we want the image just without the alpha channel
|
||||
np_img = np.array(img)
|
||||
# strip off alpha
|
||||
np_img = np_img[:, :, :3]
|
||||
img = Image.fromarray(np_img)
|
||||
|
||||
img = img.convert('RGB')
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
@@ -443,11 +454,16 @@ class MaskFileItemDTOMixin:
|
||||
self.has_mask_image = False
|
||||
self.mask_path: Union[str, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.use_alpha_as_mask: bool = False
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.mask_min_value = dataset_config.mask_min_value
|
||||
if dataset_config.mask_path is not None:
|
||||
if dataset_config.alpha_mask:
|
||||
self.use_alpha_as_mask = True
|
||||
self.mask_path = kwargs.get('path', None)
|
||||
self.has_mask_image = True
|
||||
elif dataset_config.mask_path is not None:
|
||||
# find the control image path
|
||||
mask_path = dataset_config.mask_path
|
||||
mask_path = dataset_config.mask_path if dataset_config.mask_path is not None else dataset_config.alpha_mask
|
||||
# we are using control images
|
||||
img_path = kwargs.get('path', None)
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
@@ -460,11 +476,21 @@ class MaskFileItemDTOMixin:
|
||||
|
||||
def load_mask_image(self: 'FileItemDTO'):
|
||||
try:
|
||||
img = Image.open(self.mask_path).convert('RGB')
|
||||
img = Image.open(self.mask_path)
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.mask_path}")
|
||||
|
||||
if self.use_alpha_as_mask:
|
||||
# pipeline expectws an rgb image so we need to put alpha in all channels
|
||||
np_img = np.array(img)
|
||||
np_img[:, :, :3] = np_img[:, :, 3:]
|
||||
|
||||
np_img = np_img[:, :, :3]
|
||||
img = Image.fromarray(np_img)
|
||||
|
||||
img = img.convert('RGB')
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
|
||||
Reference in New Issue
Block a user