mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Allow loading auxillery images from dataloader
This commit is contained in:
@@ -231,10 +231,10 @@ class DatasetConfig:
|
|||||||
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||||
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
||||||
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
|
||||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||||
self.augments: List[str] = kwargs.get('augments', [])
|
self.augments: List[str] = kwargs.get('augments', [])
|
||||||
|
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||||
|
|
||||||
# cache latents will store them in memory
|
# cache latents will store them in memory
|
||||||
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from PIL import Image
|
|||||||
from PIL.ImageOps import exif_transpose
|
from PIL.ImageOps import exif_transpose
|
||||||
|
|
||||||
from toolkit import image_utils
|
from toolkit import image_utils
|
||||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin
|
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||||
|
ControlFileItemDTOMixin, ArgBreakMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.config_modules import DatasetConfig
|
from toolkit.config_modules import DatasetConfig
|
||||||
@@ -21,9 +22,15 @@ def print_once(msg):
|
|||||||
printed_messages.append(msg)
|
printed_messages.append(msg)
|
||||||
|
|
||||||
|
|
||||||
class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
|
class FileItemDTO(
|
||||||
def __init__(self, **kwargs):
|
LatentCachingFileItemDTOMixin,
|
||||||
super().__init__()
|
CaptionProcessingDTOMixin,
|
||||||
|
ImageProcessingDTOMixin,
|
||||||
|
ControlFileItemDTOMixin,
|
||||||
|
ArgBreakMixin,
|
||||||
|
):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
self.path = kwargs.get('path', None)
|
self.path = kwargs.get('path', None)
|
||||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||||
# process width and height
|
# process width and height
|
||||||
@@ -58,6 +65,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
self.tensor = None
|
self.tensor = None
|
||||||
self.cleanup_latent()
|
self.cleanup_latent()
|
||||||
|
self.cleanup_control()
|
||||||
|
|
||||||
|
|
||||||
class DataLoaderBatchDTO:
|
class DataLoaderBatchDTO:
|
||||||
@@ -73,6 +81,9 @@ class DataLoaderBatchDTO:
|
|||||||
self.latents: Union[torch.Tensor, None] = None
|
self.latents: Union[torch.Tensor, None] = None
|
||||||
if is_latents_cached:
|
if is_latents_cached:
|
||||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||||
|
self.control_tensor: Union[torch.Tensor, None] = None
|
||||||
|
if self.file_items[0].control_tensor is not None:
|
||||||
|
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
|
||||||
|
|
||||||
def get_is_reg_list(self):
|
def get_is_reg_list(self):
|
||||||
return [x.is_reg for x in self.file_items]
|
return [x.is_reg for x in self.file_items]
|
||||||
@@ -95,5 +106,6 @@ class DataLoaderBatchDTO:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
del self.latents
|
del self.latents
|
||||||
del self.tensor
|
del self.tensor
|
||||||
|
del self.control_tensor
|
||||||
for file_item in self.file_items:
|
for file_item in self.file_items:
|
||||||
file_item.cleanup()
|
file_item.cleanup()
|
||||||
|
|||||||
@@ -121,7 +121,8 @@ class BucketsMixin:
|
|||||||
width = file_item.crop_width
|
width = file_item.crop_width
|
||||||
height = file_item.crop_height
|
height = file_item.crop_height
|
||||||
|
|
||||||
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, divisibility=bucket_tolerance)
|
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution,
|
||||||
|
divisibility=bucket_tolerance)
|
||||||
|
|
||||||
# set the scaling height and with to match smallest size, and keep aspect ratio
|
# set the scaling height and with to match smallest size, and keep aspect ratio
|
||||||
if width > height:
|
if width > height:
|
||||||
@@ -239,6 +240,8 @@ class ImageProcessingDTOMixin:
|
|||||||
# if we are caching latents, just do that
|
# if we are caching latents, just do that
|
||||||
if self.is_latent_cached:
|
if self.is_latent_cached:
|
||||||
self.get_latent()
|
self.get_latent()
|
||||||
|
if self.has_control_image:
|
||||||
|
self.load_control_image()
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
img = Image.open(self.path).convert('RGB')
|
img = Image.open(self.path).convert('RGB')
|
||||||
@@ -302,13 +305,79 @@ class ImageProcessingDTOMixin:
|
|||||||
img = transform(img)
|
img = transform(img)
|
||||||
|
|
||||||
self.tensor = img
|
self.tensor = img
|
||||||
|
if self.has_control_image:
|
||||||
|
self.load_control_image()
|
||||||
|
|
||||||
|
|
||||||
|
class ControlFileItemDTOMixin:
|
||||||
|
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||||
|
if hasattr(super(), '__init__'):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.has_control_image = False
|
||||||
|
self.control_path: Union[str, None] = None
|
||||||
|
self.control_tensor: Union[torch.Tensor, None] = None
|
||||||
|
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||||
|
if dataset_config.control_path is not None:
|
||||||
|
# find the control image path
|
||||||
|
control_path = dataset_config.control_path
|
||||||
|
# we are using control images
|
||||||
|
img_path = kwargs.get('path', None)
|
||||||
|
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||||
|
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||||
|
for ext in img_ext_list:
|
||||||
|
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
|
||||||
|
self.control_path = os.path.join(control_path, file_name_no_ext + ext)
|
||||||
|
self.has_control_image = True
|
||||||
|
break
|
||||||
|
|
||||||
|
def load_control_image(self: 'FileItemDTO'):
|
||||||
|
try:
|
||||||
|
img = Image.open(self.control_path).convert('RGB')
|
||||||
|
img = exif_transpose(img)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
print(f"Error loading image: {self.control_path}")
|
||||||
|
w, h = img.size
|
||||||
|
if w > h and self.scale_to_width < self.scale_to_height:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
|
||||||
|
if self.flip_x:
|
||||||
|
# do a flip
|
||||||
|
img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
if self.flip_y:
|
||||||
|
# do a flip
|
||||||
|
img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
if self.dataset_config.buckets:
|
||||||
|
# scale and crop based on file item
|
||||||
|
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||||
|
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||||
|
else:
|
||||||
|
raise Exception("Control images not supported for non-bucket datasets")
|
||||||
|
|
||||||
|
self.control_tensor = transforms.ToTensor()(img)
|
||||||
|
|
||||||
|
def cleanup_control(self: 'FileItemDTO'):
|
||||||
|
self.control_tensor = None
|
||||||
|
|
||||||
|
|
||||||
|
class ArgBreakMixin:
|
||||||
|
# just stops super calls form hitting object
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LatentCachingFileItemDTOMixin:
|
class LatentCachingFileItemDTOMixin:
|
||||||
def __init__(self):
|
def __init__(self, *args, **kwargs):
|
||||||
# if we have super, call it
|
# if we have super, call it
|
||||||
if hasattr(super(), '__init__'):
|
if hasattr(super(), '__init__'):
|
||||||
super().__init__()
|
super().__init__(*args, **kwargs)
|
||||||
self._encoded_latent: Union[torch.Tensor, None] = None
|
self._encoded_latent: Union[torch.Tensor, None] = None
|
||||||
self._latent_path: Union[str, None] = None
|
self._latent_path: Union[str, None] = None
|
||||||
self.is_latent_cached = False
|
self.is_latent_cached = False
|
||||||
|
|||||||
Reference in New Issue
Block a user