mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added ability to add masks to dataloader and sd trainer to adjust weight of image
This commit is contained in:
@@ -98,18 +98,31 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||||
network_weight_list = batch.get_network_weight_list()
|
network_weight_list = batch.get_network_weight_list()
|
||||||
|
|
||||||
adapter_images = None
|
with torch.no_grad():
|
||||||
sigmas = None
|
adapter_images = None
|
||||||
if self.adapter:
|
sigmas = None
|
||||||
# todo move this to data loader
|
if self.adapter:
|
||||||
if batch.control_tensor is not None:
|
# todo move this to data loader
|
||||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
if batch.control_tensor is not None:
|
||||||
else:
|
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||||
adapter_images = self.get_adapter_images(batch)
|
else:
|
||||||
# not 100% sure what this does. But they do it here
|
adapter_images = self.get_adapter_images(batch)
|
||||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
# not 100% sure what this does. But they do it here
|
||||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||||
|
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||||
|
|
||||||
|
mask_multiplier = 1.0
|
||||||
|
if batch.mask_tensor is not None:
|
||||||
|
# upsampling no supported for bfloat16
|
||||||
|
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||||
|
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||||
|
mask_multiplier = torch.nn.functional.interpolate(
|
||||||
|
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||||
|
)
|
||||||
|
# expand to match latents
|
||||||
|
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||||
|
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
|
||||||
# flush()
|
# flush()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
@@ -188,6 +201,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
# multiply by our mask
|
||||||
|
loss = loss * mask_multiplier
|
||||||
|
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ class DatasetConfig:
|
|||||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||||
self.augments: List[str] = kwargs.get('augments', [])
|
self.augments: List[str] = kwargs.get('augments', [])
|
||||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||||
|
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
||||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||||
|
|
||||||
# cache latents will store them in memory
|
# cache latents will store them in memory
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
|
|||||||
|
|
||||||
from toolkit import image_utils
|
from toolkit import image_utils
|
||||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
|
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.config_modules import DatasetConfig
|
from toolkit.config_modules import DatasetConfig
|
||||||
@@ -27,6 +27,7 @@ class FileItemDTO(
|
|||||||
CaptionProcessingDTOMixin,
|
CaptionProcessingDTOMixin,
|
||||||
ImageProcessingDTOMixin,
|
ImageProcessingDTOMixin,
|
||||||
ControlFileItemDTOMixin,
|
ControlFileItemDTOMixin,
|
||||||
|
MaskFileItemDTOMixin,
|
||||||
PoiFileItemDTOMixin,
|
PoiFileItemDTOMixin,
|
||||||
ArgBreakMixin,
|
ArgBreakMixin,
|
||||||
):
|
):
|
||||||
@@ -67,6 +68,7 @@ class FileItemDTO(
|
|||||||
self.tensor = None
|
self.tensor = None
|
||||||
self.cleanup_latent()
|
self.cleanup_latent()
|
||||||
self.cleanup_control()
|
self.cleanup_control()
|
||||||
|
self.cleanup_mask()
|
||||||
|
|
||||||
|
|
||||||
class DataLoaderBatchDTO:
|
class DataLoaderBatchDTO:
|
||||||
@@ -76,6 +78,8 @@ class DataLoaderBatchDTO:
|
|||||||
is_latents_cached = self.file_items[0].is_latent_cached
|
is_latents_cached = self.file_items[0].is_latent_cached
|
||||||
self.tensor: Union[torch.Tensor, None] = None
|
self.tensor: Union[torch.Tensor, None] = None
|
||||||
self.latents: Union[torch.Tensor, None] = None
|
self.latents: Union[torch.Tensor, None] = None
|
||||||
|
self.control_tensor: Union[torch.Tensor, None] = None
|
||||||
|
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||||
if not is_latents_cached:
|
if not is_latents_cached:
|
||||||
# only return a tensor if latents are not cached
|
# only return a tensor if latents are not cached
|
||||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||||
@@ -100,6 +104,21 @@ class DataLoaderBatchDTO:
|
|||||||
else:
|
else:
|
||||||
control_tensors.append(x.control_tensor)
|
control_tensors.append(x.control_tensor)
|
||||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||||
|
|
||||||
|
if any([x.mask_tensor is not None for x in self.file_items]):
|
||||||
|
# find one to use as a base
|
||||||
|
base_mask_tensor = None
|
||||||
|
for x in self.file_items:
|
||||||
|
if x.mask_tensor is not None:
|
||||||
|
base_mask_tensor = x.mask_tensor
|
||||||
|
break
|
||||||
|
mask_tensors = []
|
||||||
|
for x in self.file_items:
|
||||||
|
if x.mask_tensor is None:
|
||||||
|
mask_tensors.append(torch.zeros_like(base_mask_tensor))
|
||||||
|
else:
|
||||||
|
mask_tensors.append(x.mask_tensor)
|
||||||
|
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from toolkit.buckets import get_bucket_for_image_size
|
|||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from PIL import Image
|
from PIL import Image, ImageFilter
|
||||||
from PIL.ImageOps import exif_transpose
|
from PIL.ImageOps import exif_transpose
|
||||||
|
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
@@ -288,6 +288,8 @@ class ImageProcessingDTOMixin:
|
|||||||
self.get_latent()
|
self.get_latent()
|
||||||
if self.has_control_image:
|
if self.has_control_image:
|
||||||
self.load_control_image()
|
self.load_control_image()
|
||||||
|
if self.has_mask_image:
|
||||||
|
self.load_mask_image()
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
img = Image.open(self.path).convert('RGB')
|
img = Image.open(self.path).convert('RGB')
|
||||||
@@ -363,6 +365,8 @@ class ImageProcessingDTOMixin:
|
|||||||
self.tensor = img
|
self.tensor = img
|
||||||
if self.has_control_image:
|
if self.has_control_image:
|
||||||
self.load_control_image()
|
self.load_control_image()
|
||||||
|
if self.has_mask_image:
|
||||||
|
self.load_mask_image()
|
||||||
|
|
||||||
|
|
||||||
class ControlFileItemDTOMixin:
|
class ControlFileItemDTOMixin:
|
||||||
@@ -430,6 +434,79 @@ class ControlFileItemDTOMixin:
|
|||||||
self.control_tensor = None
|
self.control_tensor = None
|
||||||
|
|
||||||
|
|
||||||
|
class MaskFileItemDTOMixin:
|
||||||
|
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||||
|
if hasattr(super(), '__init__'):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.has_mask_image = False
|
||||||
|
self.mask_path: Union[str, None] = None
|
||||||
|
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||||
|
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||||
|
if dataset_config.mask_path is not None:
|
||||||
|
# find the control image path
|
||||||
|
mask_path = dataset_config.mask_path
|
||||||
|
# we are using control images
|
||||||
|
img_path = kwargs.get('path', None)
|
||||||
|
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||||
|
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||||
|
for ext in img_ext_list:
|
||||||
|
if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)):
|
||||||
|
self.mask_path = os.path.join(mask_path, file_name_no_ext + ext)
|
||||||
|
self.has_mask_image = True
|
||||||
|
break
|
||||||
|
|
||||||
|
def load_mask_image(self: 'FileItemDTO'):
|
||||||
|
try:
|
||||||
|
img = Image.open(self.mask_path).convert('RGB')
|
||||||
|
img = exif_transpose(img)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
print(f"Error loading image: {self.mask_path}")
|
||||||
|
w, h = img.size
|
||||||
|
if w > h and self.scale_to_width < self.scale_to_height:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
|
||||||
|
if self.flip_x:
|
||||||
|
# do a flip
|
||||||
|
img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
if self.flip_y:
|
||||||
|
# do a flip
|
||||||
|
img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
# randomly apply a blur up to 10% of the size of the min (width, height)
|
||||||
|
min_size = min(img.width, img.height)
|
||||||
|
blur_radius = int(min_size * random.random() * 0.1)
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
||||||
|
|
||||||
|
# make grayscale
|
||||||
|
img = img.convert('L')
|
||||||
|
|
||||||
|
if self.dataset_config.buckets:
|
||||||
|
# scale and crop based on file item
|
||||||
|
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||||
|
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||||
|
# crop
|
||||||
|
img = img.crop((
|
||||||
|
self.crop_x,
|
||||||
|
self.crop_y,
|
||||||
|
self.crop_x + self.crop_width,
|
||||||
|
self.crop_y + self.crop_height
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
raise Exception("Mask images not supported for non-bucket datasets")
|
||||||
|
|
||||||
|
self.mask_tensor = transforms.ToTensor()(img)
|
||||||
|
# convert to grayscale
|
||||||
|
|
||||||
|
def cleanup_mask(self: 'FileItemDTO'):
|
||||||
|
self.mask_tensor = None
|
||||||
|
|
||||||
class PoiFileItemDTOMixin:
|
class PoiFileItemDTOMixin:
|
||||||
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
||||||
# items in the poi will always be inside the image when random cropping
|
# items in the poi will always be inside the image when random cropping
|
||||||
|
|||||||
Reference in New Issue
Block a user