Fixed big issue with bucketing dataloader and added random cripping to a point of interest

This commit is contained in:
Jaret Burkett
2023-10-02 18:31:08 -06:00
parent 320e109c5f
commit 579650eaf8
6 changed files with 264 additions and 72 deletions

View File

@@ -12,7 +12,7 @@ import torch
import torch.backends.cuda
from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
@@ -931,16 +931,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
batch = next(dataloader_iterator_reg)
self.progress_bar.unpause()
is_reg_step = True
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
batch = next(dataloader_iterator)
self.progress_bar.unpause()
else:
batch = None

View File

@@ -2,6 +2,7 @@ import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import sys
import os
@@ -16,12 +17,14 @@ sys.path.append(SD_SCRIPTS_ROOT)
from library.model_util import load_vae
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
@@ -34,38 +37,44 @@ batch_size = 4
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
resolution=resolution,
caption_ext='txt',
caption_ext='json',
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,
augments=['ColorJitter', 'RandomEqualize'],
augments=['ColorJitter'],
poi='person'
)
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
# run through an epoch ang check sizes
for batch in dataloader:
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
dataloader_iterator = iter(dataloader)
for epoch in range(args.epochs):
for batch in dataloader:
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are size by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are size by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
min_val = big_img.min()
max_val = big_img.max()
min_val = big_img.min()
max_val = big_img.max()
big_img = (big_img / 2 + 0.5).clamp(0, 1)
big_img = (big_img / 2 + 0.5).clamp(0, 1)
# convert to image
img = transforms.ToPILImage()(big_img)
# convert to image
img = transforms.ToPILImage()(big_img)
show_img(img)
show_img(img)
time.sleep(1.0)
time.sleep(1.0)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)
cv2.destroyAllWindows()

View File

@@ -234,7 +234,8 @@ class DatasetConfig:
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)

View File

@@ -346,6 +346,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.epoch_num = 0
self.sd = sd
@@ -426,13 +427,20 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.setup_epoch()
def setup_epoch(self):
# TODO: set this up to redo cropping and everything else
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
if self.epoch_num == 0:
# initial setup
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest
# setup buckets every epoch
self.setup_buckets(quiet=True)
self.epoch_num += 1
def __len__(self):
if self.dataset_config.buckets:
@@ -450,6 +458,9 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# for buckets we collate ourselves for now
# todo allow a scheduler to dynamically make buckets
# we collate ourselves
if len(self.batch_indices) - 1 < item:
# tried everything to solve this. No way to reset length when redoing things. Pick another index
item = random.randint(0, len(self.batch_indices) - 1)
idx_list = self.batch_indices[item]
return [self._get_single_item(idx) for idx in idx_list]
else:
@@ -523,3 +534,27 @@ def get_dataloader_from_datasets(
collate_fn=dto_collation
)
return data_loader
def trigger_dataloader_setup_epoch(dataloader: DataLoader):
# hacky but needed because of different types of datasets and dataloaders
dataloader.len = None
if isinstance(dataloader.dataset, list):
for dataset in dataloader.dataset:
if hasattr(dataset, 'datasets'):
for sub_dataset in dataset.datasets:
if hasattr(sub_dataset, 'setup_epoch'):
sub_dataset.setup_epoch()
sub_dataset.len = None
elif hasattr(dataset, 'setup_epoch'):
dataset.setup_epoch()
dataset.len = None
elif hasattr(dataloader.dataset, 'setup_epoch'):
dataloader.dataset.setup_epoch()
dataloader.dataset.len = None
elif hasattr(dataloader.dataset, 'datasets'):
dataloader.dataset.len = None
for sub_dataset in dataloader.dataset.datasets:
if hasattr(sub_dataset, 'setup_epoch'):
sub_dataset.setup_epoch()
sub_dataset.len = None

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -27,6 +27,7 @@ class FileItemDTO(
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
def __init__(self, *args, **kwargs):
@@ -70,20 +71,25 @@ class FileItemDTO(
class DataLoaderBatchDTO:
def __init__(self, **kwargs):
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
try:
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]

View File

@@ -29,11 +29,13 @@ if TYPE_CHECKING:
transforms_dict = {
'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
'RandomEqualize': transforms.RandomEqualize(p=0.2),
}
caption_ext_list = ['txt', 'json', 'caption']
class CaptionMixin:
def get_caption_item(self: 'AiToolkitDataset', index):
if not hasattr(self, 'caption_type'):
@@ -106,66 +108,96 @@ class BucketsMixin:
self.batch_indices: List[List[int]] = []
def build_batch_indices(self: 'AiToolkitDataset'):
self.batch_indices = []
for key, bucket in self.buckets.items():
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
batch = bucket.file_list_idx[start_idx:end_idx]
self.batch_indices.append(batch)
def setup_buckets(self: 'AiToolkitDataset'):
def setup_buckets(self: 'AiToolkitDataset', quiet=False):
if not hasattr(self, 'file_list'):
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
if not hasattr(self, 'dataset_config'):
raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
if self.epoch_num > 0 and self.dataset_config.poi is None:
# no need to rebuild buckets for now
# todo handle random cropping for buckets
return
self.buckets = {} # clear it
config: 'DatasetConfig' = self.dataset_config
resolution = config.resolution
bucket_tolerance = config.bucket_tolerance
file_list: List['FileItemDTO'] = self.file_list
total_pixels = resolution * resolution
# for file_item in enumerate(file_list):
for idx, file_item in enumerate(file_list):
file_item: 'FileItemDTO' = file_item
width = file_item.crop_width
height = file_item.crop_height
width = int(file_item.width * file_item.dataset_config.scale)
height = int(file_item.height * file_item.dataset_config.scale)
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution,
divisibility=bucket_tolerance)
# set the scaling height and with to match smallest size, and keep aspect ratio
if width > height:
file_item.scale_to_height = bucket_resolution["height"]
file_item.scale_to_width = int(width * (bucket_resolution["height"] / height))
if file_item.has_point_of_interest:
# let the poi module handle the bucketing
file_item.setup_poi_bucket()
else:
file_item.scale_to_width = bucket_resolution["width"]
file_item.scale_to_height = int(height * (bucket_resolution["width"] / width))
bucket_resolution = get_bucket_for_image_size(
width, height,
resolution=resolution,
divisibility=bucket_tolerance
)
file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]
# Calculate scale factors for width and height
width_scale_factor = bucket_resolution["width"] / width
height_scale_factor = bucket_resolution["height"] / height
new_width = bucket_resolution["width"]
new_height = bucket_resolution["height"]
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
max_scale_factor = max(width_scale_factor, height_scale_factor)
file_item.scale_to_width = int(width * max_scale_factor)
file_item.scale_to_height = int(height * max_scale_factor)
file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]
new_width = bucket_resolution["width"]
new_height = bucket_resolution["height"]
if self.dataset_config.random_crop:
# random crop
crop_x = random.randint(0, file_item.scale_to_width - new_width)
crop_y = random.randint(0, file_item.scale_to_height - new_height)
file_item.crop_x = crop_x
file_item.crop_y = crop_y
else:
# do central crop
file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
if file_item.crop_y < 0 or file_item.crop_x < 0:
print('debug')
# check if bucket exists, if not, create it
bucket_key = f'{new_width}x{new_height}'
bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
if bucket_key not in self.buckets:
self.buckets[bucket_key] = Bucket(new_width, new_height)
self.buckets[bucket_key] = Bucket(file_item.crop_width, file_item.crop_height)
self.buckets[bucket_key].file_list_idx.append(idx)
# print the buckets
self.build_batch_indices()
name = f"{os.path.basename(self.dataset_path)} ({self.resolution})"
print(f'Bucket sizes for {self.dataset_path}:')
for key, bucket in self.buckets.items():
print(f'{key}: {len(bucket.file_list_idx)} files')
print(f'{len(self.buckets)} buckets made')
if not quiet:
print(f'Bucket sizes for {self.dataset_path}:')
for key, bucket in self.buckets.items():
print(f'{key}: {len(bucket.file_list_idx)} files')
print(f'{len(self.buckets)} buckets made')
# file buckets made
class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
@@ -281,10 +313,19 @@ class ImageProcessingDTOMixin:
img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
print('size mismatch')
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
else:
# Downscale the source image first
# TODO this is nto right
@@ -371,7 +412,14 @@ class ControlFileItemDTOMixin:
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
@@ -381,6 +429,93 @@ class ControlFileItemDTOMixin:
self.control_tensor = None
class PoiFileItemDTOMixin:
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
# items in the poi will always be inside the image when random cropping
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
# poi is a name of the box point of interest in the caption json file
dataset_config = kwargs.get('dataset_config', None)
path = kwargs.get('path', None)
self.poi: Union[str, None] = dataset_config.poi
self.has_point_of_interest = self.poi is not None
self.poi_x: Union[int, None] = None
self.poi_y: Union[int, None] = None
self.poi_width: Union[int, None] = None
self.poi_height: Union[int, None] = None
if self.poi is not None:
# make sure latent caching is off
if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
raise Exception(
f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
)
# make sure we are loading through json
if dataset_config.caption_ext != 'json':
raise Exception(
f"Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
)
self.poi = self.poi.strip()
# get the caption path
file_path_no_ext = os.path.splitext(path)[0]
caption_path = file_path_no_ext + '.json'
if not os.path.exists(caption_path):
raise Exception(f"Error: caption file not found for poi: {caption_path}")
with open(caption_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
if 'poi' not in json_data:
raise Exception(f"Error: poi not found in caption file: {caption_path}")
if self.poi not in json_data['poi']:
raise Exception(f"Error: poi not found in caption file: {caption_path}")
# poi has, x, y, width, height
poi = json_data['poi'][self.poi]
self.poi_x = int(poi['x'])
self.poi_y = int(poi['y'])
self.poi_width = int(poi['width'])
self.poi_height = int(poi['height'])
def setup_poi_bucket(self: 'FileItemDTO'):
# we are using poi, so we need to calculate the bucket based on the poi
resolution = self.dataset_config.resolution
bucket_tolerance = self.dataset_config.bucket_tolerance
initial_width = int(self.width * self.dataset_config.scale)
initial_height = int(self.height * self.dataset_config.scale)
poi_x = int(self.poi_x * self.dataset_config.scale)
poi_y = int(self.poi_y * self.dataset_config.scale)
poi_width = int(self.poi_width * self.dataset_config.scale)
poi_height = int(self.poi_height * self.dataset_config.scale)
# todo handle a poi that is smaller than resolution
# determine new cropping
crop_left = random.randint(0, poi_x)
crop_right = random.randint(poi_x + poi_width, initial_width)
crop_top = random.randint(0, poi_y)
crop_bottom = random.randint(poi_y + poi_height, initial_height)
new_width = crop_right - crop_left
new_height = crop_bottom - crop_top
bucket_resolution = get_bucket_for_image_size(
new_width, new_height,
resolution=resolution,
divisibility=bucket_tolerance
)
width_scale_factor = bucket_resolution["width"] / new_width
height_scale_factor = bucket_resolution["height"] / new_height
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
max_scale_factor = max(width_scale_factor, height_scale_factor)
self.scale_to_width = int(initial_width * max_scale_factor)
self.scale_to_height = int(initial_height * max_scale_factor)
self.crop_width = bucket_resolution['width']
self.crop_height = bucket_resolution['height']
self.crop_x = int(crop_left * max_scale_factor)
self.crop_y = int(crop_top * max_scale_factor)
class ArgBreakMixin:
# just stops super calls form hitting object
def __init__(self, *args, **kwargs):